import argparse
import copy
import gymnasium as gym
import json
import logging
import math
import numpy as np
import os
import pathlib
import re
import scipy.stats
import sys
import time
import traceback

from collections import deque, namedtuple
from typing import Any, List, Tuple, Dict, Optional

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList

# Module-level logger. The NullHandler means that importing this module does
# not emit log output unless the application configures logging itself.
LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


class StationaryActionClass:
    """
    Sentinel type for the "stationary" action, i.e. a step in which no new
    action packet is handed to the interaction layer.

    Equality is decided purely on type: every instance of this class equals
    every other instance, and nothing else.
    """
    def __eq__(self, other): return isinstance(other, StationaryActionClass)
    # Defining __eq__ alone sets __hash__ to None. Keep the sentinel hashable,
    # with a hash that agrees with "all instances are equal".
    def __hash__(self): return hash(StationaryActionClass)
    def __str__(self): return f" {chr(0x2205)} " # empty set unicode
    def __repr__(self): return f"<{self.__str__()}>"
StationaryAction = StationaryActionClass()


# Extended observation record handed out by the interaction layer: the raw
# observation bundled with the action buffer and its delay bookkeeping.
ExtendedState = namedtuple(
    "ExtendedState",
    (
        "t",           # timestep at which this record was applied
        "t_origin",    # timestep from which the buffered data originates
        "delay",       # delay of the generated actions
        "delayshift",  # how far the buffered actions have been shifted
        "s_obs",       # the observed state
        "a_mem",       # the actions held in the action buffer
    ),
)


class DelayProcess(gym.utils.EzPickle):
    """
    Base class for random delay processes.

    Subclasses implement `sample`, `_minsample` and `__str__`, and may
    override `distribution`, `reset`, `__int__` and `_partialcmp`. The rich
    comparison operators are all derived from `_partialcmp`.
    """
    def __init__(self, tag=None):
        # Optional human readable label; see the `tag` property.
        self._tag = tag

    def __post_init__(self):
        # NOTE(review): nothing in this class invokes this hook; it only has an
        # effect if a caller or subclass calls it explicitly.
        assert self._minsample() > 0

    @property
    def tag(self):
        """The explicit tag if one was given, otherwise `str(self)`."""
        if self._tag is not None:
            return str(self._tag)
        else:
            return self.__str__()

    def reset(self):
        """Resets the delay process internal state."""
        pass

    def sample(self) -> int:
        """Samples a delay."""
        raise NotImplementedError

    def distribution(self) -> Dict[int, float]:
        """
        Return a distribution over delays. If not specially implemented then
        we just perform a Monte Carlo estimation.

        Note that the estimation draws from `sample`, and therefore advances
        any internal state that sampling advances.
        """
        num_samples = 5000
        LOG.warning("Sampling distribution for delay %s", type(self).__name__)
        buckets = {}
        for _ in range(num_samples):
            d = self.sample()
            buckets[d] = buckets.get(d, 0) + 1
        return {d: float(c) / float(num_samples) for d, c in buckets.items()}

    @property
    def min(self):
        """The smallest delay that this process can produce."""
        return self._minsample()

    def __str__(self): raise NotImplementedError
    def __int__(self): raise NotImplementedError

    def _minsample(self):
        """Return the minimum sample value for this distribution."""
        raise NotImplementedError

    def _partialcmp(self, other):
        """
        Partial ordering of a delay process
            -1 if self < other
             0 if self = other
             1 if self > other
          None if unrelated

        The base implementation only knows that a process equals itself;
        everything else is unrelated.
        """
        # A process is always equal to itself. Without this, `d == d` would be
        # False and `d != d` True for every process lacking its own ordering.
        if other is self:
            return 0
        return None

    def __eq__(self, other): return bool(self._partialcmp(other) == 0)
    def __ne__(self, other): return bool(self._partialcmp(other) != 0)
    def __lt__(self, other): return bool(self._partialcmp(other) == -1)
    def __le__(self, other): return bool(self._partialcmp(other) in [-1, 0])
    def __gt__(self, other): return bool(self._partialcmp(other) == 1)
    def __ge__(self, other): return bool(self._partialcmp(other) in [0, 1])


class DatasetDelay(DelayProcess):
    """
    Replay of a measured delay process.

    The measurements are read from a JSON file containing a sequence of
    objects with a "latency_ms" entry. Each latency is converted to a whole
    number of time steps by dividing by `qms` and rounding up.
    """
    def __init__(self, path, qms=8, **kwargs):
        # Initialise the base class so that `_tag` exists; otherwise the `tag`
        # property raises AttributeError. Only `tag` is forwarded, so that any
        # other keyword argument keeps being accepted as before.
        super().__init__(tag=kwargs.get("tag"))
        gym.utils.EzPickle.__init__(self, path, qms=qms, **kwargs)
        self.qms = qms
        self.path = path
        with open(path, encoding="utf-8") as f:
            measurements = json.load(f)
        self.delays = np.array([int(math.ceil(m["latency_ms"] / self.qms)) for m in measurements])
        assert len(self.delays) > 0
        self.delay_idx = 0
        self.reset()

    def reset(self):
        """Restart the replay from a uniformly random position in the data."""
        self.delay_idx = np.random.randint(0, len(self.delays))

    def sample(self):
        """Return the delay at the current position, then advance cyclically."""
        d = self.delays[self.delay_idx]
        self.delay_idx += 1
        if self.delay_idx >= len(self.delays):
            self.delay_idx = 0
        return int(d)

    def distribution(self):
        """Return the empirical distribution (relative frequency per delay)."""
        counts = {}
        for d in self.delays:
            d = int(d)
            counts[d] = counts.get(d, 0.0) + 1.0

        total = len(self.delays)
        return {d: c / total for d, c in counts.items()}

    def _minsample(self):
        return int(self.delays.min())

    def __str__(self):
        return f"DatasetDelay({self.path})"



class ConstantDelay(DelayProcess):
    """A constant delay process."""
    def __init__(self, constant : int, **kwargs):
        # Initialise the base class so that `_tag` exists; otherwise the `tag`
        # property raises AttributeError. Only `tag` is forwarded, so that any
        # other keyword argument keeps being accepted as before.
        super().__init__(tag=kwargs.get("tag"))
        gym.utils.EzPickle.__init__(self, constant, **kwargs)
        self.constant = int(constant)

    def sample(self):
        return self.constant

    def distribution(self):
        return {self.constant: 1.0}

    def _minsample(self):
        return self.constant

    def _partialcmp(self, other):
        """Order against other constant delays and plain integers."""
        c = None
        if isinstance(other, ConstantDelay): c = other.constant
        elif isinstance(other, int):         c = other
        if c is not None:
            if self.constant < c: return -1
            if self.constant > c: return 1
            else:                 return 0
        return super()._partialcmp(other)

    def __str__(self):
        return f"ConstantDelay{self.constant}"

    def __int__(self):
        return self.constant


class RandomDiscreteDelay(DelayProcess):
    """
    A random delay process whose samples are all discrete and independent.

    A sample is `shift` plus a draw from `dist`, where negative draws are
    clamped to zero.
    """
    def __init__(self, dist : scipy.stats.rv_discrete, shift=1, **kwargs):
        super().__init__(**kwargs)
        # Record the constructor arguments; EzPickle needs them for pickling
        # and copying, and fails if they were never recorded.
        gym.utils.EzPickle.__init__(self, dist, shift=shift, **kwargs)
        self.dist = dist
        self.shift = shift

    def sample(self):
        return self.shift + max(0, int(self.dist.rvs()))

    def _minsample(self):
        # `dist.a` is the lower end of the support as reported by scipy.
        return self.shift + int(max(0, self.dist.a))

    def _partialcmp(self, other):
        """Only integers strictly below the minimum sample are comparable."""
        if isinstance(other, int):
            if other < self.min:
                # Every sample exceeds `other`, hence self > other.
                return 1
        return super()._partialcmp(other)

    def __str__(self):
        # Mirrors the naming used for rounded delays. Falls back gracefully if
        # `dist` is not a frozen distribution (no `.dist` / `.kwds`).
        base = getattr(self.dist, "dist", self.dist)
        name = getattr(base, "name", type(base).__name__)
        kwds = getattr(self.dist, "kwds", {})
        return "-".join(["RandomDiscreteDelay", f"{name}"] + [f"{k}{v}" for k, v in kwds.items()])


class RandomRoundedDelay(DelayProcess):
    """
    A random delay process whose samples are drawn from a continuous
    distribution and rounded up to the next integer.

    A sample is `shift` plus the rounded-up draw, where negative values are
    clamped to zero.
    """
    def __init__(self, dist : scipy.stats.rv_continuous, shift=1, **kwargs):
        super().__init__(**kwargs)
        # Record the constructor arguments; EzPickle needs them for pickling
        # and copying, and fails if they were never recorded.
        gym.utils.EzPickle.__init__(self, dist, shift=shift, **kwargs)
        self.dist = dist
        self.shift = shift

    def sample(self):
        return self.shift + max(0, int(math.ceil(self.dist.rvs())))

    def distribution(self):
        """
        Return the delay distribution computed from the CDF of `dist`.

        The probability of delay `shift + t` is cdf(t) - cdf(t - 1), with all
        mass at or below zero assigned to `shift`. Tails with a cumulative
        probability within 1e-5 of zero or one are left out.
        """
        ticks = {}
        mean = self.dist.mean()
        c_prev = 0.0
        for t in range(1000):
            c = self.dist.cdf(t)
            if t > mean and c > (1 - 1e-5):
                break
            elif c > 1e-5:
                ticks[self.shift + t] = c - c_prev
            c_prev = c
        return ticks

    def _minsample(self):
        # `dist.a` is the lower end of the support as reported by scipy.
        return self.shift + int(max(0, self.dist.a))

    def _partialcmp(self, other):
        """Only integers strictly below the minimum sample are comparable."""
        if isinstance(other, int):
            if other < self.min:
                # Every sample exceeds `other`, hence self > other.
                return 1
        return super()._partialcmp(other)

    def __str__(self):
        # NOTE(review): only keyword parameters of the frozen distribution are
        # included, so distributions built from positional arguments that
        # differ only in those arguments get the same name.
        return "-".join(["RandomRoundedDelay", f"{self.dist.dist.name}"] + [f"{k}{v}" for k, v in self.dist.kwds.items()])


class RandomCategoricalDelay(DelayProcess):
    """
    A random delay process whose samples are all categorically distributed.

    Weight number `i` (zero-indexed) is the unnormalised probability of the
    delay `shift + i`.
    """
    def __init__(self, weights, shift=1, **kwargs):
        super().__init__(**kwargs)
        # Forward the keyword arguments as well, so that a tag survives
        # pickling and copying.
        gym.utils.EzPickle.__init__(self, weights, shift, **kwargs)
        weights = np.array(weights, dtype=np.float64)
        # Reject weights that cannot be normalised into probabilities instead
        # of silently carrying NaN or negative probabilities around.
        if weights.ndim != 1 or weights.size == 0:
            raise ValueError("Weights must be a non-empty one-dimensional sequence")
        if not np.isfinite(weights).all() or (weights < 0).any() or not weights.sum() > 0:
            raise ValueError(f"Weights must be finite, non-negative and have a positive sum, got {weights.tolist()}")
        self.weights = weights / weights.sum()
        self.dist = scipy.stats.multinomial(1, self.weights)
        self.shift = shift

    def sample(self):
        # A single multinomial trial is one-hot; its argmax is the category.
        return self.shift + self.dist.rvs(1).argmax(1).item()

    def distribution(self):
        return {
            self.shift + i: w
            for i, w in enumerate(self.weights)
        }

    def _minsample(self):
        return self.shift

    def __str__(self):
        return "Categorical(" + ",".join(
            [f"{w:.3}" for w in self.weights] + (
                [] if self.shift == 1 else [f"shift={self.shift}"]
            )
        ) + ")"


class HiddenMarkovianDelay(DelayProcess):
    """
    A delay process driven by a hidden Markov chain.

    Each state of the chain owns a delay process. A sample is drawn from the
    process of the current state, after which the chain moves to the next
    state according to the (row-normalised) transition matrix.
    """
    def __init__(self,
                 delays: List[DelayProcess],
                 transition_matrix: np.ndarray,
                 initial_state: int = 0,
                 **kwargs):
        super().__init__(**kwargs)
        # Forward the keyword arguments as well, so that a tag survives
        # pickling and copying.
        gym.utils.EzPickle.__init__(self, delays, transition_matrix, initial_state, **kwargs)
        self.delays = delays
        self.initial_state = initial_state
        assert len(self.delays) > 0
        assert self.initial_state in range(len(self.delays))
        assert all(isinstance(d, DelayProcess) for d in self.delays)

        # Check the shape before normalising, so that a malformed matrix is
        # reported here rather than as an obscure numpy error.
        L = len(self.delays)
        trmx = np.array(transition_matrix, dtype=np.float64)
        assert trmx.shape == (L, L)
        row_sums = trmx.sum(axis=1, keepdims=True)
        assert (row_sums > 0).all(), "Every row of the transition matrix must have a positive sum"
        self.trmx = trmx / row_sums

        # Cumulative PMF for transition matrix
        self.cumtrmx = self.trmx.cumsum(axis=1)

        self._current_state = self.initial_state

    def reset(self):
        """Revert to the initial state, and reset all contained delays."""
        for d in self.delays:
            d.reset()
        self._current_state = self.initial_state

    def sample(self):
        sd = self.delays[self._current_state].sample()

        # Sample next state by inverting the cumulative row.
        hits = np.random.rand() < self.cumtrmx[self._current_state]
        if hits.any():
            nxt = hits.argmax().item()
        else:
            # Rounding left the cumulative row marginally below the drawn
            # number: pick the last state that is reachable from here.
            nxt = np.flatnonzero(self.trmx[self._current_state])[-1].item()
        self._current_state = nxt

        return sd

    def distribution(self):
        """
        Return the long-run delay distribution: the distributions of the
        contained delays weighted by the limiting state probabilities.
        """
        # Compute the steady/stationary state (assuming we can reach any state)
        itertrmx = self.trmx
        for i in range(1000):
            new_itertrmx = itertrmx @ self.trmx
            if np.isclose(itertrmx, new_itertrmx).all():
                LOG.debug(f"Markov distribution converged after {i} iterations")
                itertrmx = new_itertrmx
                break
            itertrmx = new_itertrmx

        # The chain starts in `initial_state`, so read off that row. Under the
        # assumption above all rows agree, making this equal to any other row.
        state_probs = itertrmx[self.initial_state]

        ticks = {}
        for i, d in enumerate(self.delays):
            for t, v in d.distribution().items():
                ticks[t] = ticks.get(t, 0.0) + (v * state_probs[i])

        return ticks

    def _minsample(self):
        return min(d._minsample() for d in self.delays)

    def __str__(self):
        return "".join([
            "Markovian(",
            ", ".join([str(d) for d in self.delays] + [str(self.trmx.tolist())]),
            ")"
        ])


class MM1QueueDelay(DelayProcess):
    """
    This models an M/M/1 Queue where it in the best case only has a delay of 1,
    and it increases with the arrival times.

    The delay is the time a packet spends in the system (from insertion into
    the queue until it has been served), rounded up to the nearest whole
    number.
    """
    def __init__(self,
                 rate_arrive: float,
                 rate_service: float,
                 **kwargs):
        super().__init__(**kwargs)
        # Forward the keyword arguments as well, so that a tag survives
        # pickling and copying.
        gym.utils.EzPickle.__init__(self, rate_arrive, rate_service, **kwargs)
        self.rate_arrive = rate_arrive
        self.rate_service = rate_service
        assert self.rate_arrive > 0, f"Arrival rate must be a positive value, got {self.rate_arrive}"
        assert self.rate_service > 0, f"Service rate must be a positive value, got {self.rate_service}"
        assert self.rate_arrive < self.rate_service, f"Service rate ({self.rate_service}) is not greater than arrival rate ({self.rate_arrive}), system is unstable."

        # From the documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.expon.html#scipy.stats.expon
        # Need to use 1/lambda as the scale
        self._dist_arrive = scipy.stats.expon(scale=1 / self.rate_arrive)
        self._dist_service = scipy.stats.expon(scale=1 / self.rate_service)

        self.reset()

    @property
    def rho(self):
        """The ratio between the arrival rate and the service rate."""
        return (self.rate_arrive / self.rate_service)

    def reset(self):
        """
        Revert the delay process to the initial state.

        The queue _q contains packet insertion times. The delay is the
        difference between time of insertion to time of service for a
        packet.
        """
        self._q = deque()
        self._next_arrival_time = self._dist_arrive.rvs().item()
        # None means that the server is idle (nothing is being served).
        self._next_service_time = None

    def sample(self):
        """Serve the next packet and return the delay it experienced."""
        if self._next_service_time is None:
            # Idle server: jump ahead to the next arrival and start serving it.
            assert len(self._q) == 0
            t = self._next_arrival_time
            self._q.appendleft(t)
            # Setup next arrival and service time
            self._next_arrival_time = t + self._dist_arrive.rvs().item()
            self._next_service_time = t + self._dist_service.rvs().item()

        # Accept packets until we can serve
        while self._next_arrival_time < self._next_service_time:
            assert len(self._q) > 0
            t = self._next_arrival_time
            self._q.appendleft(t)
            self._next_arrival_time += self._dist_arrive.rvs().item()

        # Now it is time to service
        t = self._next_service_time
        t_inserted = self._q.pop()
        # Never go below the advertised minimum of 1, not even if a drawn
        # service time happens to be exactly zero.
        d = max(self._minsample(), int(math.ceil(t - t_inserted)))

        if len(self._q) > 0:
            self._next_service_time += self._dist_service.rvs().item()
        else:
            self._next_service_time = None

        return d

    def distribution(self):
        """Return a Monte Carlo estimate of the delay distribution."""
        # There is no closed form implemented for the rounded time in the
        # system, so for now this just simulates via the base class.
        return super().distribution()

    def _minsample(self):
        return 1

    def __str__(self):
        return "".join([
            f"MM1Queue({self.rate_arrive}, {self.rate_service}",
            ")",
        ])


def _partial_delay_from_string(s) -> Tuple[DelayProcess, str]:
    """
    Parse one delay specification from the beginning of `s`.

    Returns the parsed delay process together with the text that follows the
    specification, which lets the Markovian form parse its nested arguments
    one after another.

    Recognised forms (function names are case-insensitive):
        <digits>                               constant delay
        norm(loc, scale[, shift])              rounded normal delay
        categorical(w0, w1, ...[, shift=k])    categorical delay
        markov(<delay>, ..., [[...], ...])     hidden Markovian delay
        mm1q(rate_arrive, rate_service)        M/M/1 queue delay
        dataset(name)                          replay of a bundled dataset
    """
    spec = s.strip()

    # A leading run of digits is a constant delay.
    leading_digits = re.match(r"^\d+", spec)
    if leading_digits is not None:
        (begin, end) = leading_digits.span()
        return (ConstantDelay(int(spec[begin:end])), spec[end:])

    # Everything else must look like a function call: name(...)
    lparen = spec.find("(")
    if not lparen > 0:
        raise ValueError(f"Unknown delay specification {spec}")

    rparen = spec.find(")")
    assert lparen < rparen
    fname = spec[:lparen].lower().strip()

    def call_args():
        # Comma separated arguments between the first "(" and the first ")"
        return [a.strip() for a in spec[lparen+1:rparen].split(",")]

    if fname in ["norm", "normal", "gaussian", "gauss"]:
        args = call_args()
        if len(args) not in [2, 3]:
            raise ValueError("Expected 2 or 3 arguments to normal delay distribution")
        loc = float(args[0])
        scale = float(args[1])
        shift = int(args[2]) if len(args) == 3 else 1 # default shift
        return (RandomRoundedDelay(scipy.stats.norm(loc, scale), shift=shift), spec[rparen+1:])

    if fname in ["weighted", "weight", "categorical", "cat", "ordinal", "ord"]:
        args = call_args()
        shift = 1
        if len(args) > 0 and args[-1].startswith("shift"):
            # Optional trailing "shift=<int>" argument
            shiftarg = args.pop()
            eqidx = shiftarg.find("=")
            assert eqidx > 0
            shift = int(shiftarg[eqidx+1:].strip())
        assert len(args) > 0, "Must have at least 1 weight"
        weights = [float(w) for w in args]
        return (RandomCategoricalDelay(weights, shift=shift), spec[rparen+1:])

    if fname in ["markovian", "markov"]:
        rest = spec[lparen+1:].strip()
        children = []
        # Nested delay specifications, each followed by a comma, up until the
        # final argument: the transition matrix, which starts with "[".
        while not rest.startswith("["):
            (child, rest) = _partial_delay_from_string(rest)
            children.append(child)
            rest = rest.strip()
            assert rest.startswith(","), "Expected a comma after argument"
            rest = rest[1:].strip()
        close_idx = rest.find(")")
        assert close_idx > 0
        transition_matrix = json.loads(rest[:close_idx].strip())
        return (HiddenMarkovianDelay(children, transition_matrix), rest[close_idx+1:])

    if fname in ["mm1q", "mm1queue", "mm1", "m/m/1"]:
        args = call_args()
        if len(args) != 2:
            raise ValueError("Expected 2 arguments to an M/M/1 queue delay distribution")
        rate_arrive = float(args[0])
        rate_service = float(args[1])
        return (MM1QueueDelay(rate_arrive, rate_service), spec[rparen+1:])

    if fname in ["dataset"]:
        args = call_args()
        if len(args) != 1:
            raise ValueError("Expected one argument (the path) to a dataset delay distribution")
        # Datasets are looked up next to this source file.
        here = pathlib.Path(__file__).absolute().parent
        path = here / f"results.{args[0]}.8ms-100k.json"
        return (DatasetDelay(str(path), qms=8), spec[rparen+1:])

    raise ValueError(f"Unknown function {fname}")


def delay_from_string(s):
    """
    Build a delay process from its string specification.

    The specification may be prefixed by a non-empty tag followed by "::",
    in which case the tag is attached to the resulting delay process.
    Raises ValueError if anything is left over after the specification.
    """
    (head, sep, tail) = s.partition("::")
    if sep and head:
        # "<tag>::<specification>"
        tag, spec = head, tail
    else:
        tag, spec = None, s

    (delay, leftover) = _partial_delay_from_string(spec)
    if leftover.strip() != "":
        raise ValueError(f"Found trailing contents in delay string: \"{leftover}\"")

    delay._tag = tag
    return delay


class SimulatedInteractionLayer(gym.Wrapper, gym.utils.RecordConstructorArgs):
    def __init__(self, env,
                 delay=0,
                 horizon=1,
                 default_action=None,
                 flatten_observations=False):
        """
        Creates a wrapped environment representing a simulated interaction
        layer that wraps an underlying environment.

        This class becomes quite complex because we need to simulate the
        interaction behavior. Since the delay can be random, it is possible
        that the delay of an action is much lower than that of its predecessor.

        We abstract that away in the extended_step function by assuming that,
        for each observation received, the next applied action is generated
        from that observation. Thus any action received will be generated from
        the latest available data.

        When applying actions, you can specify actions for any possible future
        time slot in which those actions should be inserted. The minimum for
        the actual latency will be the one inserted into the action buffer.

        Parameters
        ----------
        env : gym.core.Env
            The gym environment to wrap. Its observation space must have a
            floating point dtype.

        delay : DelayProcess, optional
            A delay process specifying the interaction delay. Can also be
            specified as an integer constant or a scipy distribution, which
            will be converted to the appropriate delay process. (Default: 0)

            Note that the minimum delay is asserted to be strictly positive,
            so the default of 0 (as well as None) is rejected by that check.

        horizon : int, optional
            The delay horizon, how many actions that can be stored in the
            action buffer. (Default: 1)

        default_action : np.ndarray, optional
            The default action to apply as soon as the environment is turned
            on. (Default: An array of zeros corresponding to the action space
            of the environment.)

        flatten_observations : bool
            Whether the observations should be flattened such that underlying
            state and buffered actions become part of the emitted observation.
            Useful if you want to apply the delayed MDP setting for regular
            algorithms. Requires a 1-dim Box observation space and either a
            1-dim Box or a Discrete action space. (Default: False)
        """
        # Record the constructor arguments before initialising the wrapper.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            delay=delay,
            horizon=horizon,
            default_action=default_action,
            flatten_observations=flatten_observations,
        )
        gym.Wrapper.__init__(self, env)

        # Normalise the various accepted forms of `delay` to a DelayProcess.
        if delay is None:
            delay = ConstantDelay(0)
        elif isinstance(delay, int):
            delay = ConstantDelay(delay)
        elif isinstance(delay, (scipy.stats.rv_discrete, scipy.stats._distn_infrastructure.rv_discrete_frozen)):
            delay = RandomDiscreteDelay(delay)
        elif isinstance(delay, (scipy.stats.rv_continuous, scipy.stats._distn_infrastructure.rv_continuous_frozen)):
            delay = RandomRoundedDelay(delay)
        elif not isinstance(delay, DelayProcess):
            raise ValueError(f"The delay must either be an integer, a distribution, or an explicit delay process. Got {type(delay)}")

        assert delay._minsample() > 0, "Must have a positive minimum delay to use this wrapper"

        #assert isinstance(delay, scipy.stats.rv_discrete), "delay must be modeled as a discrete distribution"
        assert isinstance(horizon, int) and horizon > 0
        assert env.observation_space.dtype.kind == "f", f"Only numpy float types supported as of now (got {env.observation_space.dtype})"

        self.delay = delay
        self.horizon = horizon
        self.default_action = default_action
        if self.default_action is None:
            self.default_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)

        assert self.default_action in self.env.action_space

        # Setup the observation and action spaces
        self.flatten_observations = flatten_observations
        if flatten_observations:
            # Encode the X = (s,a_0,...,a_{l-1}) tuple as a single numpy array
            assert isinstance(env.observation_space, gym.spaces.Box), "Flattening of observations requires Box space"
            assert len(env.observation_space.shape) == 1, "Only 1-dim spaces supported as of now"
            obs_len = env.observation_space.shape[0]
            if isinstance(env.action_space, gym.spaces.Box):
                assert len(env.action_space.shape) == 1, "Only 1-dim spaces supported as of now"
                act_len = env.action_space.shape[0]
                act_low = env.action_space.low
                act_high = env.action_space.high
            elif isinstance(env.action_space, gym.spaces.Discrete):
                # A discrete action occupies a single slot holding its index.
                act_len = 1
                act_low = 0
                act_high = env.action_space.n - 1
            else:
                raise ValueError(f"Unsupported action space: {env.action_space}")
            # Bounds of the flat observation: first the raw observation, then
            # one action-sized segment per slot of the action buffer.
            flatobs_low = np.zeros((obs_len + self.horizon*act_len,), dtype=np.float32)
            flatobs_high = np.zeros((obs_len + self.horizon*act_len,), dtype=np.float32)
            flatobs_low[0:obs_len] = env.observation_space.low
            flatobs_high[0:obs_len] = env.observation_space.high
            for l in range(self.horizon):
                flatobs_low[obs_len + l*act_len:obs_len + (l+1)*act_len] = act_low
                flatobs_high[obs_len + l*act_len:obs_len + (l+1)*act_len] = act_high
            flatobs_space = gym.spaces.Box(low=flatobs_low, high=flatobs_high, dtype=np.float32)
            self.raw_obs_len = obs_len
            self.raw_act_len = act_len
            self._observation_space = flatobs_space
            self._action_space = self.env.action_space
        else:
            self._observation_space = self.env.observation_space
            self._action_space = self.env.action_space

        # Simulation state, populated when the environment is reset.
        self._t: int = None
        self._action_buffer: ExtendedState = None
        # Action packets that have been sent but not yet applied, stored as
        # (apply time, generation time, delay, actions) tuples.
        self._in_transit: List[Tuple[int, int, int, "Horizon of Actions"]] = None
        # Delay statistics, reported under the "delay" key of the info dict.
        self._blank_delay_info = {"samples": []}
        self._delay_info = copy.deepcopy(self._blank_delay_info)

    def delaycopy(self, delay, horizon=None):
        """
        Constructs a copy of this environment, but with a new delay.

        The underlying environment is shallow-copied, while the default
        action and the observation flattening are carried over unchanged.

        If this class is subclassed, then this ought to be subclassed as well.

        Parameters
        ----------
        delay : DelayProcess
            The delay of the copy, in any form accepted by the constructor.

        horizon : int, optional
            The horizon of the copy. (Default: the horizon of this instance)
        """
        if horizon is None:
            horizon = self.horizon

        return SimulatedInteractionLayer(
            env=copy.copy(self.env),
            delay=delay,
            horizon=horizon,
            # Previously omitted, which silently replaced a configured default
            # action with the all-zeros action in the copy.
            default_action=self.default_action,
            flatten_observations=self.flatten_observations,
        )

    @property
    def latency(self):
        """Deprecated alias: `latency` is the same object as `delay`."""
        return self.delay

    def _internal_reset(self, *,
                        seed: Optional[int] = None,
                        options: Optional[Dict[str, Any]] = None):
        """
        Reset the underlying environment together with all simulation state.

        The clock is set to zero, the action buffer is filled with the default
        action, packets in transit are dropped and the recorded delay samples
        are cleared. Returns the initial extended observation and the info
        dict of the underlying environment, extended with a "delay" entry.
        """
        raw_obs, info = self.env.reset(seed=seed, options=options)

        self._t = 0
        self._in_transit = []

        # The buffer initially holds `horizon` copies of the default action.
        initial_actions = np.array([self.default_action]*self.horizon)
        self._action_buffer = ExtendedState(
            t=None,
            t_origin=None,
            delay=None,
            delayshift=0,
            s_obs=None,
            a_mem=initial_actions,
        )

        self._delay_info = copy.deepcopy(self._blank_delay_info)
        info["delay"] = copy.deepcopy(self._delay_info)

        first_ext_obs = self._action_buffer._replace(t=self._t, s_obs=raw_obs)
        return (first_ext_obs, info)

    def _internal_step(self):
        """
        Advance the underlying environment by one step and update the action
        buffer.

        The action at the front of the buffer is applied. Afterwards the
        buffer is either replaced by the packet scheduled for the new time
        step, or shifted one slot forward with the last action repeated.

        The returned extended observation is for time t, and its a_mem holds
        the upcoming actions, the first of which is applied in the next step.
        """
        buffer = self._action_buffer

        # Apply the action at the front of the buffer.
        s_raw, r, terminated, truncated, info_dict = self.env.step(buffer.a_mem[0])
        self._t += 1

        # Has a packet arrived at the interaction layer for this timestep?
        pending = self._in_transit
        if len(pending) > 0 and (pending[0][0] == self._t):
            (_, packet_gen, packet_delay, packet_actions) = pending[0]
            self._action_buffer = buffer._replace(
                t_origin=packet_gen,
                delay=packet_delay,
                delayshift=0,
                a_mem=np.array(packet_actions),
            )
            self._in_transit = pending[1:]
        else:
            # Nothing arrived: drop the applied action and repeat the last one.
            shifted_actions = np.concatenate([
                buffer.a_mem[1:],
                np.array([buffer.a_mem[-1]]),
            ])
            self._action_buffer = buffer._replace(
                delayshift=buffer.delayshift + 1,
                a_mem=shifted_actions,
            )

        # Hand out the buffer contents together with the new observation.
        ext_obs = self._action_buffer._replace(t=self._t, s_obs=s_raw)

        info_dict["delay"] = copy.deepcopy(self._delay_info)
        return (ext_obs, r, terminated, truncated, info_dict)

    def _internal_add_to_actionbuffer(self, action_seqs):
        """
        Send a packet of candidate action sequences towards the action buffer.

        The input is a list of (t_gen, delay, actions) tuples, where t_gen is
        the time the actions were generated at and delay is the delay that
        the actions are intended for (None matches any delay).

        A delay is sampled and recorded. The first candidate whose intended
        delay places it exactly at the resulting arrival time is scheduled;
        if there is none, nothing is scheduled. Packets already in transit
        that would arrive at the same time or later are discarded.

        The caller must guarantee that the list is non-empty and that none of
        the actions are the empty action.
        """
        sampled_delay = self.delay.sample()
        self._delay_info["samples"].append(sampled_delay)

        # The time at which the packet arrives at the action buffer.
        arrival_t = self._t + sampled_delay

        chosen = None
        for (t_gen, intended_delay, actions) in action_seqs:
            if intended_delay is None:
                candidate_t = arrival_t
            else:
                candidate_t = t_gen + intended_delay
            if candidate_t == arrival_t:
                chosen = (arrival_t, t_gen, candidate_t - t_gen, actions)
                break

        # Keep it simple, discard anything that doesn't match the sampled delay
        if chosen is None:
            return

        # Everything in transit from the first packet that arrives no earlier
        # than this one is outdated.
        cut = next(
            (idx for idx, packet in enumerate(self._in_transit) if arrival_t <= packet[0]),
            None,
        )
        if cut is not None:
            self._in_transit = self._in_transit[:cut]

        self._in_transit.append(chosen)
        return

    def _transform_extended_observation(self, ext_obs : ExtendedState, as_ndarray=False):
        """
        Convert an extended observation from the internal representation to
        the one that the user expects.

        With flattening enabled, s_obs becomes a single float32 array holding
        the raw observation followed by the buffered actions. If as_ndarray
        is set, a_mem is passed through np.asarray.
        """
        if not self.flatten_observations:
            s = ext_obs.s_obs
        else:
            # Flatten S = (s,a_0,...,a_{l-1}) into a single array value
            s = np.zeros(self.observation_space.shape, dtype=np.float32)
            s[0:self.raw_obs_len] = ext_obs.s_obs
            begin = self.raw_obs_len
            for a in ext_obs.a_mem:
                end = begin + self.raw_act_len
                s[begin:end] = a
                begin = end

        a_mem = np.asarray(ext_obs.a_mem) if as_ndarray else ext_obs.a_mem

        return ext_obs._replace(s_obs=s, a_mem=a_mem)

    def extended_reset(self, *,
                       seed: Optional[int] = None,
                       options: Optional[Dict[str, Any]] = None,
                       as_ndarray: bool = False):
        """
        Reset the environment and return the initial extended observation
        together with the info dict.

        If a seed is given, it is also used to seed the observation and
        action spaces of this wrapper.
        """
        initial_ext_obs, info = self._internal_reset(seed=seed, options=options)

        if seed is not None:
            for space in (self.observation_space, self.action_space):
                space.seed(seed)

        transformed = self._transform_extended_observation(initial_ext_obs, as_ndarray=as_ndarray)
        return transformed, info

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None):
        """
        Regular reset: performs an extended reset, but returns only the
        observation part of the extended observation (plus the info dict).
        """
        ext_obs, info = self.extended_reset(seed=seed, options=options)
        observation = ext_obs.s_obs
        return observation, info

    def extended_step(self, action_packet, as_ndarray=False):
        """
        An extended step, taking an action packet as input. An action packet
        consists of a timetag of the observation packet it was generated from
        together with a matrix where the index of each row indicates the delay
        it should be applied for (the first row is for a delay of 1).

        Example of an action packet:
        (
          t,
          [[a1_1, a1_2, a1_3, ...],
           [a2_1, a2_2, a2_3, ...],
           ...
           [aL_1, aL_2, aL_3, ...]]
        )

        If this action packet arrives at time t + d, then the contents of the
        action buffer would be replaced by [ad_1, ad_2, ad_3, ...].

        The matrix may be given as a numpy array or as anything that
        np.asarray converts to one (such as the nested lists above). Passing
        StationaryAction instead of a packet steps the environment without
        sending anything to the action buffer.

        Returns the tuple (extended observation, reward, terminated,
        truncated, info).

        See the _internal_step function for details on how this works.
        """
        # Keep the sentinel on the left-hand side, so that its own __eq__
        # decides the comparison.
        if not (StationaryAction == action_packet):
            (t, M) = action_packet
            # Accept nested sequences as well; a no-op for numpy arrays.
            M = np.asarray(M)
            assert M.ndim >= 2, f"The actions must form a matrix (got {M.ndim} dimension(s))"
            assert M.shape[0] > 0, "Must provide at least one action"
            assert M.shape[1] == self.horizon, f"Number of matrix columns does not match the horizon (got {M.shape[1]})"

            # Row i holds the actions intended for a delay of i + 1.
            action_seqs = [
                (t, i + 1, M[i])
                for i in range(M.shape[0])
            ]
            self._internal_add_to_actionbuffer(action_seqs)

        ext_obs, r, terminated, truncated, info_dict = self._internal_step()
        new_ext_obs = self._transform_extended_observation(ext_obs, as_ndarray=as_ndarray)

        return new_ext_obs, r, terminated, truncated, info_dict

    def step(self, action):
        """
        A regular _opaque_ step function to be used with regular MDPs.

        The action is sent as a packet that matches any sampled delay. When
        the packet arrives, the entire action buffer is replaced by copies of
        this action. Returns the plain (non-extended) observation.
        """
        # One candidate, with no intended delay, repeating the action over the
        # whole horizon.
        repeated = np.stack([action]*self.horizon)
        self._internal_add_to_actionbuffer([(self._t, None, repeated)])

        ext_obs, r, terminated, truncated, info_dict = self._internal_step()
        observation = self._transform_extended_observation(ext_obs, as_ndarray=False).s_obs

        return observation, r, terminated, truncated, info_dict

    def seed(self, *args, **kwargs):
        """Forward all arguments to `seed` of the wrapped environment."""
        return self.env.seed(*args, **kwargs)
    def close(self, *args, **kwargs):
        """Forward all arguments to `close` of the wrapped environment."""
        return self.env.close(*args, **kwargs)


class ActionMemorizer(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """
    Observation wrapper that remembers the most recent actions and, unless
    ``obs_passthrough`` is set, appends them to every observation.
    """

    def __init__(self, env, horizon=1, default_action=None, obs_passthrough=False):
        """
        Creates a wrapper that memorizes the n last actions (specified by
        horizon).

        Parameters
        ----------
        env : gym.core.Env
            The gym environment to wrap. Its observation space and action
            space must both be Box spaces with at least one dimension.

        horizon : int, optional
            The delay horizon, how many actions that will be memorized.
            (Default: 1)

        default_action : np.ndarray, optional
            The default action used to fill the action memory when no other
            actions are provided. Must be contained in the action space of
            the environment. (Default: An array of zeros corresponding to
            the action space of the environment.)

        obs_passthrough : bool, optional
            Do not report back any of the memorized actions, just pass through
            the underlying observation.
            (Default: False)
        """
        horizon = int(horizon)
        assert horizon > 0

        # NOTE(review): default_action is not among the recorded constructor
        # arguments -- confirm that this is intended.
        gym.utils.RecordConstructorArgs.__init__(self, horizon=horizon, obs_passthrough=obs_passthrough)
        gym.ObservationWrapper.__init__(self, env)

        # Validate the spaces first: everything below relies on the Box
        # attributes (shape, dtype, low, high) of both spaces.
        assert isinstance(env.observation_space, gym.spaces.Box), "Requires Box obs space (TODO: can use other things here too...)"
        assert isinstance(env.action_space, gym.spaces.Box), "Requires Box avt space (TODO: can use other things here too...)"

        assert len(env.observation_space.shape) > 0, "Currently only supports boxed space with at least 1 dimension"
        assert len(env.action_space.shape) > 0, "Currently only supports boxed space with at least 1 dimension"

        self._horizon = horizon
        self._default_action = default_action
        self._obs_passthrough = obs_passthrough
        if self._default_action is None:
            self._default_action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype)

        assert self._default_action in env.action_space, "default_action must be contained in the action space"

        if self._obs_passthrough:
            self._observation_space = env.observation_space
        else:
            # Flat layout: [observation, action_1, ..., action_horizon]
            obs_low = np.concatenate(
                [env.observation_space.low.flatten()] +
                [env.action_space.low.flatten()] * self._horizon
            )
            obs_high = np.concatenate(
                [env.observation_space.high.flatten()] +
                [env.action_space.high.flatten()] * self._horizon
            )
            self._observation_space = gym.spaces.Box(low=obs_low, high=obs_high, dtype=np.float32)

        # Oldest action first, newest action last.
        self._action_trace = deque(maxlen=self._horizon)
        self._t = 0

    @property
    def delay(self):
        """The ``delay`` attribute of the wrapped environment."""
        return self.env.delay

    @property
    def horizon(self):
        """Number of actions that are memorized."""
        return self._horizon

    @property
    def flatten_observations(self):
        """The ``flatten_observations`` attribute of the wrapped environment."""
        return self.env.flatten_observations

    @property
    def default_action(self):
        """The action used to fill the action memory on reset."""
        return self._default_action

    def reset(self, *args, **kwargs):
        """
        Refills the action memory with copies of the default action, resets
        the local timestep and resets the wrapped environment.
        """
        self._action_trace.clear()
        self._t = 0
        for _ in range(self._horizon):
            self._action_trace.append(copy.copy(self._default_action))
        return gym.ObservationWrapper.reset(self, *args, **kwargs)

    def step(self, action):
        """
        Memorizes the action and steps the wrapped environment with an action
        packet built from the action memory.

        Raises NotImplementedError if the wrapped environment is not a
        SimulatedInteractionLayer.
        """
        self._action_trace.append(copy.copy(action))
        if not isinstance(self.env, SimulatedInteractionLayer):
            raise NotImplementedError("This should never happen during any of our benchmarks...")

        atl = list(self._action_trace) # atl = action trace list
        # This is the constant-delay augmentation from the paper: row i is
        # the action trace with its i oldest actions dropped, padded at the
        # end with i copies of the newest action.
        M = np.stack([
            np.stack(atl[i:] + ([atl[-1]] * i))
            for i in range(self._horizon)
        ]) # [L, L, A]

        actpkt = (self._t, M)

        ext_obs, reward, terminated, truncated, info = self.env.extended_step(actpkt)

        self._t += 1
        assert ext_obs.t == self._t, f"Mismatch: {ext_obs.t} != {self._t}"
        return self.observation(ext_obs.s_obs), reward, terminated, truncated, info

    def observation(self, observation):
        """
        Returns the observation unchanged when passing through, otherwise the
        flattened observation followed by the flattened memorized actions
        (oldest first) as a single float32 vector.
        """
        assert len(self._action_trace) == self._horizon
        if self._obs_passthrough:
            return observation
        else:
            return np.concatenate(
                [observation.flatten()] +
                [a.flatten() for a in self._action_trace],
                dtype=np.float32
            )


class NoisyActionWrapper(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
    """
    Adds a bit of gaussian noise to actions prior to applying them to the
    environment. This can be used to induce uncertainty into environments where
    we do not have control of the internal state.
    """
    def __init__(self, env, noise=0.05):
        """
        Parameters
        ----------
        env : gym.Env
            The environment to wrap.

        noise : float (optional)
            Proportional gaussian standard deviation noise w.r.t. upper and
            lower bound of action space.
        """
        gym.utils.RecordConstructorArgs.__init__(self, noise=noise)
        gym.ActionWrapper.__init__(self, env)

        # Width of the action range per dimension. Where the difference is
        # infinite or undefined (nan), a unit width is used instead so that
        # the noise scale stays finite.
        minmax_diff = np.nan_to_num(
            self.env.action_space.high - self.env.action_space.low,
            nan=1.0,
            posinf=1.0,
            neginf=1.0,
        )
        self._action_noise = minmax_diff * noise

    def action(self, action):
        """Adds noise to the action and ensures that it stays within the action space."""
        new_action = scipy.stats.norm.rvs(
            loc=action,
            scale=self._action_noise,
        ).clip(
            min=self.action_space.low,
            max=self.action_space.high,
        )
        # The sampled values are float64. A Box space only contains arrays
        # whose dtype can be safely cast to the dtype of the space, so cast
        # back for floating point spaces (e.g. float32). Rounding to the
        # nearest representable value cannot leave the clipped bounds, since
        # the bounds are representable in the dtype of the space.
        space_dtype = self.action_space.dtype
        if space_dtype is not None and np.issubdtype(space_dtype, np.floating):
            new_action = new_action.astype(space_dtype, copy=False)
        return new_action



### Training loop for delayed MDPs ###

class AllSaveToggling(argparse.BooleanOptionalAction):
    """
    Boolean flag that switches all individual save options at once.

    The positive form of the option enables, and the ``--no-`` form disables,
    the ``save_tensorboard``, ``save_periodic_model``, ``save_best_model`` and
    ``save_full_model`` attributes of the namespace. The destination of the
    option itself is not written by this action.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        enabled = not option_string.startswith("--no-")
        for dest in ("save_tensorboard", "save_periodic_model",
                     "save_best_model", "save_full_model"):
            setattr(namespace, dest, enabled)

# Command line arguments understood by the training loop below.
training_arguments = ArgList(
    # --- Resuming and saving ---
    resume_previous=Arg(
        "--resume",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Resume from a previous checkpoint.",
    ),
    save_tensorboard=Arg(
        "--tensorboard",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save tensorboard metrics.",
    ),
    save_periodic_model=Arg(
        "--save-periodic",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the state of the trained model at periodic intervals.",
    ),
    save_best_model=Arg(
        "--save-best",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the state of the best performing model.",
    ),
    save_full_model=Arg(
        "--save-full",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the state of the full model at the end of training.",
    ),
    # Sets all four save_* options above at once (see AllSaveToggling).
    disable_save=Arg(
        "--save",
        action=AllSaveToggling,
        default=True,
        _auto_default=False,
        help="Enable or disable saving of all metrics and models to disk.",
    ),
    # --- Output locations ---
    trainer_prefix=Arg(
        "--trainer-prefix",
        metavar="PREFIX",
        type=str,
        default=None,
        _auto_default=False,
        help="The prefix to use to save/load the trainer. (Default: <training_alg>_L<latency>)",
    ),
    logdir=Arg(
        "--logdir",
        metavar="DIR",
        type=str,
        default="_logs",
        help="The directory to use for saving tensorboard logs and best models to.",
    ),
    # --- Training and evaluation schedule ---
    trainer_eval_latencies=Arg(
        "--trainer-eval-latencies",
        metavar="N1,N2,...",
        type=at.nonnegint_commalist,
        default=None,
        help="Latencies to evaluate policies at, as a comma separated list",
    ),
    trainer_save_interval=Arg(
        "--trainer-save-interval",
        metavar="N",
        type=at.posint,
        default=100_000,
        help="Interval in number of steps to save policy on",
    ),
    trainer_iterations=Arg(
        "--trainer-iterations",
        metavar="N",
        type=at.posint,
        default=200,
        help="Number of iterations",
    ),
    trainer_itersteps=Arg(
        "--trainer-itersteps",
        metavar="N",
        type=at.posint,
        default=10_000,
        help="Minimum number of training steps to perform in one iteration",
    ),
    trainer_eval_runs=Arg(
        "--trainer-eval-runs",
        metavar="N",
        type=at.posint,
        default=10,
        help="Number of runs to average over when evaluating environment",
    ),
    trainer_random_iterations=Arg(
        "--trainer-random-iterations",
        metavar="N",
        type=at.nonnegint,
        default=1,
        help="How many of initial iterations that actions should be chosen randomly.",
    ),
)


def training_loop(cls, cls_argmaker, denv, mk_denv, args, name=None, logger=None):
    """
    Training loop for a delayed MDP environment.

    Resolves the run name, the logger and the log directory, sets up logging
    to ``<logdir>/log.txt`` and then runs the wrapped training loop. Any
    exception raised by the wrapped training loop is logged together with its
    traceback and then re-raised.

    Arguments
    ---------
    cls : TrainingAlgorithm class
    cls_argmaker : Function that returns a tuple with the arguments to the trainer class
    denv : An instantiated delayed MDP environment
    mk_denv : Function creating delayed MDP environments (passed on to the wrapped training loop)
    args : Command line arguments
    name : Name used in the default trainer prefix and logger name (Default: cls.__name__)
    logger : Logger to report to (Default: a logger named after this function and `name`)
    """
    from .training.utils import setup_logging

    if name is None:
        name = cls.__name__

    if logger is None:
        LOCAL_LOG = logging.getLogger(f"{__name__}:training_loop({name})")
        LOCAL_LOG.addHandler(logging.NullHandler())
    else:
        LOCAL_LOG = logger

    # Without an explicit prefix, derive one from the algorithm name and the
    # delay configuration of the environment.
    if args.trainer_prefix is None:
        trainer_prefix = f"{name}_L{denv.delay.tag}_hzn{denv.horizon}"
        if denv.flatten_observations:
            trainer_prefix += "_flatobs"
    else:
        trainer_prefix = args.trainer_prefix

    logdir = pathlib.Path(args.logdir) / trainer_prefix

    LOGFILE_PATH = logdir / "log.txt"

    setup_logging(logfile_name=LOGFILE_PATH)

    # Call the training loop, but catch any exceptions that occur.
    try:
        _wrapped_training_loop(cls, cls_argmaker, denv, mk_denv, args,
                               logdir=logdir,
                               trainer_prefix=trainer_prefix,
                               LOCAL_LOG=LOCAL_LOG)
    except Exception as e:
        LOCAL_LOG.error(f"Training loop failed due to {type(e).__name__} exception: {str(e)}\nTrace:{traceback.format_exc()}")
        # Bare raise re-raises the active exception with its traceback intact.
        raise

def _wrapped_training_loop(cls, cls_argmaker, denv, mk_denv, args,
                           logdir,
                           trainer_prefix,
                           LOCAL_LOG,
                          ):
    """
    Wrapped version of the training loop, ensuring that all exceptions can be
    logged.

    Alternates training and evaluation for ``args.trainer_iterations``
    iterations. Checkpoints are written to ``<logdir>/checkpoints`` and the
    evaluation results to ``<logdir>/eval.json``.

    Arguments
    ---------
    cls : TrainingAlgorithm class (used to construct and to load trainers)
    cls_argmaker : Function that returns a tuple with the arguments to the trainer class
    denv : An instantiated delayed MDP environment (only its `delay` is read here)
    mk_denv : Function called as `mk_denv(delay=d)` to create one evaluation environment per evaluated latency
    args : Command line arguments. `args.trainer_eval_latencies` is replaced in place by `{denv.delay.tag: denv.delay}` if it is None.
    logdir : pathlib.Path to the directory of this training run
    trainer_prefix : Name of this training run, stored in the evaluation data
    LOCAL_LOG : Logger to report to
    """
    from .training.utils import log_write_setup_tensorboard

    MODEL_PATH = logdir / "checkpoints"
    EVALFILE_PATH = logdir / "eval.json"

    # NOTE(review): the code below calls .items() and .keys() on this value,
    # so a mapping from key to delay is expected -- confirm that callers
    # convert a parsed --trainer-eval-latencies list before calling.
    if args.trainer_eval_latencies is None:
        args.trainer_eval_latencies = {denv.delay.tag: denv.delay}

    os.makedirs(MODEL_PATH, mode=0o755, exist_ok=True)

    eval_jsondata = {
        "args": {
            "sys.argv": [str(a) for a in sys.argv],
            "argparse": {str(k): str(v) for k, v in vars(args).items()}
        },
        "trainer_prefix": trainer_prefix,
        "evaluations": [],
    }

    LOCAL_LOG.debug(f"logdir: {logdir}")
    LOCAL_LOG.debug(f"trainer_prefix: {trainer_prefix}")
    LOCAL_LOG.debug(f"args.trainer_eval_latencies: {args.trainer_eval_latencies}")
    LOCAL_LOG.debug(f"using device: {args.device}")

    # Best return for a policy on this particular latency
    best_policies = {dkey: {
        "path": MODEL_PATH / f"best_{dkey}",
        "denv": mk_denv(delay=d),
        "total_reward": -np.inf,
    } for dkey, d in args.trainer_eval_latencies.items()}

    trainer = None
    loaded_trainer = False

    if args.resume_previous:
        try:
            trainer = cls.load(MODEL_PATH / "model")
            loaded_trainer = True
        except Exception as e:
            # The creation of a new trainer is announced once, further down.
            LOCAL_LOG.warning(f"Could not load previous checkpoint, due to {type(e).__name__} exception: {str(e)}\nTrace:{traceback.format_exc()}")
            trainer = None
            loaded_trainer = False
    if trainer is None:
        LOCAL_LOG.info("Creating new trainer")
        trainer = cls(*cls_argmaker())
        if args.save_tensorboard:
            trainer.logname = log_write_setup_tensorboard(logdir / "tb", append_date=True)
    else:
        if args.save_tensorboard:
            log_write_setup_tensorboard(trainer.logname, append_date=False)

    # Find previous bests if we had loaded a trainer
    if loaded_trainer:
        try:
            with open(EVALFILE_PATH) as f:
                loaded_eval_jsondata = json.load(f)
        except Exception as e:
            LOCAL_LOG.warning(f"could not load previous eval json data, due to {type(e).__name__} exception: {str(e)}")
        else:
            eval_jsondata = loaded_eval_jsondata

        # Re-evaluate the stored best policies to restore the reward that a
        # new policy has to beat.
        for dkey in args.trainer_eval_latencies.keys():
            try:
                best_trainer = cls.load(best_policies[dkey]["path"])
                eval_info = best_trainer.evaluate_many(
                    runs=args.trainer_eval_runs,
                    render=False,
                    env=best_policies[dkey]["denv"],
                )
                best_policies[dkey]["total_reward"] = eval_info.r_mean
                LOCAL_LOG.info(f"Previous best on {dkey}: {best_policies[dkey]['total_reward']}")
            except Exception as e:
                LOCAL_LOG.warning(f"Could not load previous best trainer on {dkey}: {str(e)}")

    # The first savepoint is the smallest multiple of the save interval that
    # is strictly larger than the current iteration count of the trainer.
    EVAL_SAVE_INTERVAL = args.trainer_save_interval
    next_eval_savepoint = trainer.total_iterations + EVAL_SAVE_INTERVAL
    next_eval_savepoint = int(np.floor(next_eval_savepoint / EVAL_SAVE_INTERVAL) * EVAL_SAVE_INTERVAL)

    # Function for estimating the total training time remaining
    iter_timestamps = []
    iter_timedeltas = []
    def estimate_itertime(i_remaining):
        """
        Records a timestamp for the iteration that is about to start and
        returns a dict that has a human readable "remaining" entry once at
        least one iteration duration is known.
        """
        ret = {}
        t_now = time.time()
        if len(iter_timestamps) > 0:
            iter_timedeltas.append(t_now - iter_timestamps[-1])
        iter_timestamps.append(t_now)
        header_parts = []
        if len(iter_timedeltas) > 0:
            # Average over (at most) the five most recent iterations
            recent_timedeltas = iter_timedeltas[-5:]
            avg_time = sum(recent_timedeltas) / len(recent_timedeltas)
            rem_time = i_remaining * avg_time
            (hh, rem_time) = (int(rem_time / 3600.0), math.fmod(rem_time, 3600.0))
            (mm, rem_time) = (int(rem_time / 60.0), math.fmod(rem_time, 60.0))
            ss = int(rem_time)
            # Units that are zero are left out of the message
            for tt, tname in [(hh, "hour"), (mm, "minute"), (ss, "second")]:
                if tt == 1:  header_parts.append(f"{tt} {tname}")
                elif tt > 1: header_parts.append(f"{tt} {tname}s")
        if len(header_parts) > 0:
            ret["remaining"] = ", ".join(header_parts)
        return ret

    LOCAL_LOG.debug("Starting training session...")
    for i in range(args.trainer_iterations):
        time_info = estimate_itertime(args.trainer_iterations - i)
        header_suffix = f" (estimated time remaining: {time_info['remaining']})" if "remaining" in time_info else ""

        LOCAL_LOG.info(f"Iteration {i+1}/{args.trainer_iterations}{header_suffix}")

        trainer.set_mode("train")
        # The policy is not followed during the first
        # `trainer_random_iterations` iterations, unless a trainer was resumed.
        trainer.train(
            n_steps=args.trainer_itersteps,
            follow_policy=bool(i >= args.trainer_random_iterations or loaded_trainer),
        )
        trainer.set_mode("eval")

        eval_entry = {
            "iterations": trainer.total_iterations,
            "latency_eval": {},
        }

        for dkey in args.trainer_eval_latencies.keys():
            LOCAL_LOG.debug(f"Evaluating for latency {dkey}...")
            eval_info = trainer.evaluate_many(
                runs=args.trainer_eval_runs,
                render=False,
                env=best_policies[dkey]["denv"],
            )
            r_total = eval_info.r_mean
            LOCAL_LOG.debug(f"Total Reward: {r_total:.5f} ± {eval_info.r_std:.3f}")
            LOCAL_LOG.debug(f"All rewards: {list(eval_info.r_all)}")
            LOCAL_LOG.debug(f"All trajectory lengths: {list(eval_info.l_all)}")
            trainer.record_evaluation_value(f"latency={dkey}", r_total)
            if r_total > best_policies[dkey]["total_reward"]:
                LOCAL_LOG.info(f"New best ({r_total:.2f}) for latency {dkey}")
                best_policies[dkey]["total_reward"] = r_total
                if args.save_best_model:
                    LOCAL_LOG.info("saving network")
                    trainer.save(best_policies[dkey]["path"], evalcopy=True)

            eval_entry["latency_eval"][dkey] = {
                "returns": [float(v) for v in list(eval_info.r_all)],
                "lengths": [int(v) for v in list(eval_info.l_all)],
                "delays": eval_info.delays,
            }

        # The evaluation file is rewritten in full after every iteration
        eval_jsondata["evaluations"].append(eval_entry)
        with open(EVALFILE_PATH, "w+") as f:
            json.dump(eval_jsondata, f, indent=2)

        if trainer.total_iterations >= next_eval_savepoint:
            # Save state for evaluation purposes
            if args.save_periodic_model:
                trainer.save(MODEL_PATH / f"{trainer.total_iterations}", evalcopy=True)
            next_eval_savepoint += EVAL_SAVE_INTERVAL

        if args.save_full_model:
            trainer.save(MODEL_PATH / "model")

    LOCAL_LOG.debug("Training session finished")
