import gymnasium as gym
import logging
import json
import numpy as np
import math
import pathlib
import re
import scipy.stats
from collections import deque, namedtuple
from typing import Any, List, Tuple, Dict, Optional

# Module-level logger. A NullHandler is attached so that this module never
# configures log output on its own; where records end up is left to whatever
# handlers the application installs.
LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


class DelayProcess(gym.utils.EzPickle):
    """
    Base class for random delay processes.

    Concrete processes implement `sample`, `_minsample` and `__str__`, and may
    override `reset`, `distribution` and `_partialcmp`.
    """
    def __init__(self, tag=None):
        self._tag = tag

    def __post_init__(self):
        assert self._minsample() > 0

    @property
    def tag(self):
        """The explicit tag as a string, or `str(self)` when no tag was set."""
        return self.__str__() if self._tag is None else str(self._tag)

    def reset(self):
        """Resets the delay process internal state (a no-op by default)."""
        pass

    def sample(self) -> int:
        """Draws one delay from the process."""
        raise NotImplementedError

    def distribution(self) -> Dict[int, float]:
        """
        Return a mapping from delay to probability.

        Unless a subclass overrides this, the probabilities are a Monte Carlo
        estimate built from 5000 calls to `sample`.
        """
        LOG.warning(f"Sampling distribution for delay {type(self).__name__}")
        counts = {}
        for delay in (self.sample() for _ in range(5000)):
            counts[delay] = 1 + counts.get(delay, 0)
        return {delay: float(n) / 5000.0 for delay, n in counts.items()}

    @property
    def min(self):
        """Smallest delay this process can produce."""
        return self._minsample()

    def __str__(self):
        raise NotImplementedError

    def __int__(self):
        raise NotImplementedError

    def _minsample(self):
        """Return the minimum sample value for this distribution."""
        raise NotImplementedError

    def _partialcmp(self, other):
        """
        Partial ordering of a delay process against `other`:
            -1 if self < other
             0 if self = other
             1 if self > other
          None if unrelated

        The base implementation treats everything as unrelated.
        """
        return None

    # All rich comparisons are derived from `_partialcmp`. An unrelated pair
    # (None) makes every comparison except `!=` evaluate to False.
    def __eq__(self, other):
        return bool(self._partialcmp(other) == 0)

    def __ne__(self, other):
        return bool(self._partialcmp(other) != 0)

    def __lt__(self, other):
        return bool(self._partialcmp(other) == -1)

    def __le__(self, other):
        return bool(self._partialcmp(other) in (-1, 0))

    def __gt__(self, other):
        return bool(self._partialcmp(other) == 1)

    def __ge__(self, other):
        return bool(self._partialcmp(other) in (0, 1))


class DatasetDelay(DelayProcess):
    """
    Replay of a measured delay process.

    The measurements are loaded from a JSON file containing a sequence of
    objects with a ``latency_ms`` entry. Each latency is divided by `qms` and
    rounded up to a whole number of steps. Sampling walks through the
    measurements in order, wrapping around at the end; `reset` jumps to a
    random position.

    Parameters
    ----------
    path : str
        Path of the JSON measurement file.

    qms : number (optional)
        Divisor applied to every ``latency_ms`` value before rounding up.

    **kwargs
        Recorded for pickling; a ``tag`` entry is also used as the process tag.
    """
    def __init__(self, path, qms=8, **kwargs):
        # The base initialiser sets `_tag`; without it the `tag` property
        # raises AttributeError. Only `tag` is forwarded so that any other
        # keyword keeps being accepted exactly as before.
        super().__init__(tag=kwargs.get("tag"))
        gym.utils.EzPickle.__init__(self, path, qms=qms, **kwargs)
        self.qms = qms
        self.path = path
        with open(path) as f:
            measurements = json.load(f)
        self.delays = np.array([int(math.ceil(m["latency_ms"] / self.qms)) for m in measurements])
        assert len(self.delays) > 0, f"No measurements found in {path}"
        self.delay_idx = 0
        self.reset()

    def reset(self):
        """Restart the replay from a uniformly random position."""
        self.delay_idx = np.random.randint(0, len(self.delays))

    def sample(self):
        """Return the delay at the current position and advance (cyclically)."""
        d = self.delays[self.delay_idx]
        self.delay_idx = (self.delay_idx + 1) % len(self.delays)
        return int(d)

    def distribution(self):
        """Return the empirical distribution of the loaded delays."""
        total = len(self.delays)
        ticks = {}
        for d in self.delays:
            d = int(d)
            ticks[d] = ticks.get(d, 0.0) + 1.0
        return {d: count / total for d, count in ticks.items()}

    def _minsample(self):
        return int(self.delays.min())

    def __str__(self):
        return f"DatasetDelay({self.path})"


class ConstantDelay(DelayProcess):
    """
    A constant delay process: every sample equals `constant`.

    Parameters
    ----------
    constant : int
        The delay returned by every call to `sample` (converted with `int`).

    **kwargs
        Recorded for pickling; a ``tag`` entry is also used as the process tag.
    """
    def __init__(self, constant : int, **kwargs):
        # The base initialiser sets `_tag`; without it the `tag` property
        # raises AttributeError. Only `tag` is forwarded so that any other
        # keyword keeps being accepted exactly as before.
        super().__init__(tag=kwargs.get("tag"))
        gym.utils.EzPickle.__init__(self, constant, **kwargs)
        self.constant = int(constant)

    def sample(self):
        return self.constant

    def distribution(self):
        return {self.constant: 1.0}

    def _minsample(self):
        return self.constant

    def _partialcmp(self, other):
        """Totally ordered against other constant delays and plain ints."""
        if isinstance(other, ConstantDelay):
            c = other.constant
        elif isinstance(other, int):
            c = other
        else:
            return super()._partialcmp(other)

        if self.constant < c:
            return -1
        if self.constant > c:
            return 1
        return 0

    def __str__(self):
        return f"ConstantDelay{self.constant}"

    def __int__(self):
        return self.constant


class RandomDiscreteDelay(DelayProcess):
    """
    A random delay process whose samples are all discrete and independent.

    Each sample is ``shift + max(0, int(dist.rvs()))``.

    Parameters
    ----------
    dist : scipy.stats.rv_discrete
        Distribution object whose `rvs()` can be called without arguments.

    shift : int (optional)
        Constant added to every (non-negative) draw.

    **kwargs
        Forwarded to `DelayProcess` (e.g. ``tag``).
    """
    def __init__(self, dist : scipy.stats.rv_discrete, shift=1, **kwargs):
        super().__init__(**kwargs)
        # Record the constructor arguments; EzPickle needs them to pickle and
        # re-create this object.
        gym.utils.EzPickle.__init__(self, dist, shift=shift, **kwargs)
        self.dist = dist
        self.shift = shift

    def sample(self):
        return self.shift + max(0, int(self.dist.rvs()))

    def _minsample(self):
        return self.shift + int(max(0, self.dist.a))

    def _partialcmp(self, other):
        # An int strictly below the smallest possible sample means this
        # process is strictly greater than it (1 == "self > other").
        if isinstance(other, int) and other < self.min:
            return 1
        return super()._partialcmp(other)

    def __str__(self):
        # Works for frozen distributions (which carry .dist/.args/.kwds) as
        # well as for plain distribution instances.
        base = getattr(self.dist, "dist", self.dist)
        name = getattr(base, "name", type(base).__name__)
        args = getattr(self.dist, "args", ())
        kwds = getattr(self.dist, "kwds", {})
        return "-".join(
            ["RandomDiscreteDelay", f"{name}"]
            + [f"{a}" for a in args]
            + [f"{k}{v}" for k, v in kwds.items()]
        )


class RandomRoundedDelay(DelayProcess):
    """
    A random delay process whose continuous samples are rounded up to the next
    integer.

    Each sample is ``shift + max(0, ceil(dist.rvs()))``.

    Parameters
    ----------
    dist : scipy.stats.rv_continuous
        Frozen distribution; `rvs()`, `cdf()`, `mean()` and `a` are used, and
        `__str__` additionally reads `dist.dist.name` and `dist.kwds`.

    shift : int (optional)
        Constant added to every (non-negative) rounded draw.

    **kwargs
        Forwarded to `DelayProcess` (e.g. ``tag``).
    """
    def __init__(self, dist : scipy.stats.rv_continuous, shift=1, **kwargs):
        super().__init__(**kwargs)
        # Record the constructor arguments; EzPickle needs them to pickle and
        # re-create this object.
        gym.utils.EzPickle.__init__(self, dist, shift=shift, **kwargs)
        self.dist = dist
        self.shift = shift

    def sample(self):
        return self.shift + max(0, int(math.ceil(self.dist.rvs())))

    def distribution(self):
        """
        Return the delay distribution computed from the CDF.

        The mass of delay ``shift + t`` is ``cdf(t) - cdf(t - 1)``, with
        everything at or below zero collapsed into ``t = 0``. Tails with a
        cumulative probability below 1e-5 (lower) or above 1 - 1e-5 (upper,
        once past the mean) are dropped, and at most 1000 steps are examined.
        """
        mean = self.dist.mean()  # loop invariant
        ticks = {}
        c_prev = 0.0
        for t in range(1000):
            c = self.dist.cdf(t)
            if t > mean and c > (1 - 1e-5):
                break
            elif c > 1e-5:
                ticks[self.shift + t] = c - c_prev
            c_prev = c
        return ticks

    def _minsample(self):
        return self.shift + int(max(0, self.dist.a))

    def _partialcmp(self, other):
        # An int strictly below the smallest possible sample means this
        # process is strictly greater than it (1 == "self > other").
        if isinstance(other, int) and other < self.min:
            return 1
        return super()._partialcmp(other)

    def __str__(self):
        # NOTE(review): only keyword arguments of the frozen distribution are
        # included; positionally supplied parameters do not appear here.
        return "-".join(["RandomRoundedDelay", f"{self.dist.dist.name}"] + [f"{k}{v}" for k, v in self.dist.kwds.items()])


class RandomCategoricalDelay(DelayProcess):
    """
    A random delay process whose samples are all categorically distributed.

    Category ``i`` (0-based) corresponds to the delay ``shift + i`` and is
    drawn with probability ``weights[i] / sum(weights)``.

    Parameters
    ----------
    weights : sequence of float
        Non-negative, not necessarily normalised, category weights.

    shift : int (optional)
        Delay assigned to the first category.

    **kwargs
        Forwarded to `DelayProcess` (e.g. ``tag``).

    Raises
    ------
    ValueError
        If `weights` is not a non-empty 1-D sequence of non-negative numbers
        with a positive, finite sum.
    """
    def __init__(self, weights, shift=1, **kwargs):
        super().__init__(**kwargs)
        # Keep the keyword arguments in the pickled state too, so a tag given
        # at construction survives a pickle round trip.
        gym.utils.EzPickle.__init__(self, weights, shift, **kwargs)
        weights = np.array(weights, dtype=np.float64)
        total = weights.sum()
        if weights.ndim != 1 or weights.size == 0:
            raise ValueError("weights must be a non-empty 1-D sequence")
        if (weights < 0).any() or not np.isfinite(total) or total <= 0:
            raise ValueError("weights must be non-negative with a positive, finite sum")
        self.weights = weights / total
        self.dist = scipy.stats.multinomial(1, self.weights)
        self.shift = shift

    def sample(self):
        # A single multinomial trial is a one-hot row; argmax picks the drawn
        # category.
        return self.shift + self.dist.rvs(1).argmax(1).item()

    def distribution(self):
        return {self.shift + i: w for i, w in enumerate(self.weights)}

    def _minsample(self):
        return self.shift

    def __str__(self):
        parts = [f"{w:.3}" for w in self.weights]
        if self.shift != 1:
            parts.append(f"shift={self.shift}")
        return "Categorical(" + ",".join(parts) + ")"


class HiddenMarkovianDelay(DelayProcess):
    """
    A delay process driven by a hidden Markov chain.

    Every state of the chain owns one delay process. A sample is drawn from
    the process of the current state, after which the chain moves to its next
    state according to the (row-normalised) transition matrix.

    Parameters
    ----------
    delays : list of DelayProcess
        One delay process per state.

    transition_matrix : array-like
        Square matrix of non-negative transition weights, one row per state.
        Rows are normalised to sum to one.

    initial_state : int (optional)
        Index of the state the chain starts in and returns to on `reset`.

    **kwargs
        Forwarded to `DelayProcess` (e.g. ``tag``).
    """
    def __init__(self,
                 delays: List[DelayProcess],
                 transition_matrix: np.ndarray,
                 initial_state: int = 0,
                 **kwargs):
        super().__init__(**kwargs)
        # Keep the keyword arguments in the pickled state too, so a tag given
        # at construction survives a pickle round trip.
        gym.utils.EzPickle.__init__(self, delays, transition_matrix, initial_state, **kwargs)
        self.delays = delays
        self.initial_state = initial_state
        assert len(self.delays) > 0
        assert self.initial_state in range(len(self.delays))
        assert all(isinstance(d, DelayProcess) for d in self.delays)

        n_states = len(self.delays)
        trmx = np.array(transition_matrix, dtype=np.float64)
        # Check the shape before normalising so that a malformed matrix is
        # reported as such instead of as an obscure reshape error.
        assert trmx.shape == (n_states, n_states), (
            f"Transition matrix must have shape {(n_states, n_states)}, got {trmx.shape}"
        )
        self.trmx = trmx / trmx.sum(axis=1, keepdims=True)

        # Cumulative PMF for transition matrix. Round-off can leave the final
        # entry of a row slightly below one; a uniform draw above it would
        # match no column and argmax would then silently select state 0.
        # Pinning the last column to one closes that gap.
        self.cumtrmx = self.trmx.cumsum(axis=1)
        self.cumtrmx[:, -1] = 1.0

        self._current_state = self.initial_state

    def reset(self):
        """Revert to the initial state, and reset all contained delays."""
        for d in self.delays:
            d.reset()
        self._current_state = self.initial_state

    def sample(self):
        """Sample from the current state's process, then advance the chain."""
        sd = self.delays[self._current_state].sample()

        # Inverse-CDF draw of the next state: the first column whose cumulative
        # probability exceeds the uniform draw.
        u = np.random.rand()
        self._current_state = (u < self.cumtrmx[self._current_state]).argmax().item()

        return sd

    def distribution(self):
        """
        Return the mixture of the per-state distributions, weighted by the
        long-run state probabilities.

        The state probabilities are taken from the first row of a high power
        of the transition matrix (at most 1000 multiplications), which equals
        the stationary distribution when every state can be reached.
        """
        itertrmx = self.trmx
        for i in range(1000):
            new_itertrmx = itertrmx @ self.trmx
            converged = np.isclose(itertrmx, new_itertrmx).all()
            itertrmx = new_itertrmx
            if converged:
                LOG.debug(f"Markov distribution converged after {i} iterations")
                break

        state_probs = itertrmx[0]
        ticks = {}
        for p, d in zip(state_probs, self.delays):
            for t, v in d.distribution().items():
                ticks[t] = ticks.get(t, 0.0) + (v * p)

        return ticks

    def _minsample(self):
        return min(d._minsample() for d in self.delays)

    def __str__(self):
        return "".join([
            "Markovian(",
            ", ".join([str(d) for d in self.delays] + [str(self.trmx.tolist())]),
            ")"
        ])


class MM1QueueDelay(DelayProcess):
    """
    This models an M/M/1 queue: packets arrive with exponentially distributed
    inter-arrival times and are served one at a time, in order of arrival,
    with exponentially distributed service times.

    The delay of a packet is the time from its arrival until its service
    completes (waiting plus service), rounded up to the nearest whole number.

    Parameters
    ----------
    rate_arrive : float
        Arrival rate; must be positive and smaller than `rate_service`.

    rate_service : float
        Service rate; must be positive.

    **kwargs
        Forwarded to `DelayProcess` (e.g. ``tag``).
    """
    def __init__(self,
                 rate_arrive: float,
                 rate_service: float,
                 **kwargs):
        super().__init__(**kwargs)
        # Keep the keyword arguments in the pickled state too, so a tag given
        # at construction survives a pickle round trip.
        gym.utils.EzPickle.__init__(self, rate_arrive, rate_service, **kwargs)
        self.rate_arrive = rate_arrive
        self.rate_service = rate_service
        assert self.rate_arrive > 0, f"Arrival rate must be a positive value, got {self.rate_arrive}"
        assert self.rate_service > 0, f"Service rate must be a positive value, got {self.rate_service}"
        assert self.rate_arrive < self.rate_service, f"Service rate ({self.rate_service}) is not greater than arrival rate ({self.rate_arrive}), system is unstable."

        # From the documentation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.expon.html#scipy.stats.expon
        # Need to use 1/lambda as the scale
        self._dist_arrive = scipy.stats.expon(scale=1 / self.rate_arrive)
        self._dist_service = scipy.stats.expon(scale=1 / self.rate_service)

        self.reset()

    @property
    def rho(self):
        """Ratio of arrival rate to service rate."""
        return (self.rate_arrive / self.rate_service)

    def reset(self):
        """
        Revert the delay process to the initial state.

        The queue _q contains packet insertion times. The delay is the
        difference between time of insertion to time of service for a
        packet.
        """
        self._q = deque()
        self._next_arrival_time = self._dist_arrive.rvs().item()
        self._next_service_time = None

    def sample(self):
        """Serve the oldest queued packet and return its rounded-up delay."""
        if self._next_service_time is None:
            # The server is idle: jump ahead to the next arrival and start
            # serving that packet immediately.
            assert len(self._q) == 0
            t = self._next_arrival_time
            self._q.appendleft(t)
            # Setup next arrival and service time
            self._next_arrival_time = t + self._dist_arrive.rvs().item()
            self._next_service_time = t + self._dist_service.rvs().item()

        # Accept packets until we can serve
        while self._next_arrival_time < self._next_service_time:
            assert len(self._q) > 0
            self._q.appendleft(self._next_arrival_time)
            self._next_arrival_time += self._dist_arrive.rvs().item()

        # Now it is time to service. New packets enter on the left, so the
        # right end holds the packet that has waited longest.
        t = self._next_service_time
        t_inserted = self._q.pop()
        d = int(math.ceil(t - t_inserted))

        if len(self._q) > 0:
            # The next packet starts service as soon as this one is done.
            self._next_service_time += self._dist_service.rvs().item()
        else:
            self._next_service_time = None

        return d

    def distribution(self):
        """Return the delay distribution, estimated by simulation."""
        # No closed form is used here; fall back to the Monte Carlo estimate
        # of the base class.
        return super().distribution()

    def _minsample(self):
        return 1

    def __str__(self):
        return "".join([
            f"MM1Queue({self.rate_arrive}, {self.rate_service}",
            ")",
        ])


def _partial_delay_from_string(s) -> Tuple[DelayProcess, str]:
    """
    Parse one delay specification from the start of `s`.

    Recognised forms (function names are case-insensitive):

    - ``<digits>``: a constant delay.
    - ``norm(loc, scale[, shift])``: a rounded normal delay.
    - ``cat(w0, w1, ...[, shift=K])``: a categorical delay.
    - ``markov(<delay>, ..., [[...], ...])``: a hidden Markovian delay whose
      last argument is the transition matrix as a JSON list.
    - ``mm1(rate_arrive, rate_service)``: an M/M/1 queue delay.
    - ``dataset(name)``: a dataset delay read from
      ``results.<name>.8ms-100k.json`` next to this file.

    Returns
    -------
    (DelayProcess, str)
        The parsed delay process and the unparsed remainder of `s`.

    Raises
    ------
    ValueError
        If the specification is malformed or names an unknown function.
    """
    s = s.strip()
    num_match = re.match(r"^\d+", s)
    if num_match is not None:
        end = num_match.end()
        return (ConstantDelay(int(s[:end])), s[end:])

    li = s.find("(")
    if li <= 0:
        raise ValueError(f"Unknown delay specification {s}")
    ri = s.find(")")
    if ri < li:
        raise ValueError(f"Missing closing parenthesis in delay specification {s}")

    fname = s[:li].lower().strip()
    # Flat argument list up to the first ")". The markov branch ignores it
    # because its arguments may themselves contain parentheses.
    args = [a.strip() for a in s[li+1:ri].split(",")]
    rest = s[ri+1:]

    if fname in ["norm", "normal", "gaussian", "gauss"]:
        if len(args) not in [2, 3]:
            raise ValueError("Expected 2 or 3 arguments to normal delay distribution")
        (loc, scale) = (float(args[0]), float(args[1]))
        shift = int(args[2]) if len(args) == 3 else 1  # default shift
        return (RandomRoundedDelay(scipy.stats.norm(loc, scale), shift=shift), rest)

    if fname in ["weighted", "weight", "categorical", "cat", "ordinal", "ord"]:
        shift = 1
        if len(args) > 0 and args[-1].startswith("shift"):
            shiftarg = args.pop()
            eqidx = shiftarg.find("=")
            if eqidx <= 0:
                raise ValueError(f"Expected shift=<int>, got {shiftarg}")
            shift = int(shiftarg[eqidx+1:].strip())
        if len(args) == 0:
            raise ValueError("Must have at least 1 weight")
        weights = [float(w) for w in args]
        return (RandomCategoricalDelay(weights, shift=shift), rest)

    if fname in ["markovian", "markov"]:
        rem_str = s[li+1:].strip()
        dists = []
        # Every argument before the "[" that opens the transition matrix is
        # itself a delay specification, parsed recursively.
        while not rem_str.startswith("["):
            (dist, rem_str) = _partial_delay_from_string(rem_str)
            dists.append(dist)
            rem_str = rem_str.strip()
            if not rem_str.startswith(","):
                raise ValueError("Expected a comma after argument")
            rem_str = rem_str[1:].strip()
        # final argument: the transition matrix
        end = rem_str.find(")")
        if end <= 0:
            raise ValueError("Missing closing parenthesis after transition matrix")
        transition_matrix = json.loads(rem_str[:end].strip())
        return (HiddenMarkovianDelay(dists, transition_matrix), rem_str[end+1:])

    if fname in ["mm1q", "mm1queue", "mm1", "m/m/1"]:
        if len(args) != 2:
            raise ValueError("Expected 2 arguments to an M/M/1 queue delay distribution")
        (rate_arrive, rate_service) = (float(args[0]), float(args[1]))
        return (MM1QueueDelay(rate_arrive, rate_service), rest)

    if fname in ["dataset"]:
        if len(args) != 1:
            raise ValueError("Expected one argument (the path) to a dataset delay distribution")
        path = pathlib.Path(__file__).absolute().parent / f"results.{args[0]}.8ms-100k.json"
        return (DatasetDelay(str(path), qms=8), rest)

    raise ValueError(f"Unknown function {fname}")


def delay_from_string(s):
    """
    Build a delay process from its textual specification.

    The specification may be prefixed with ``<tag>::`` where the tag is
    non-empty; the tag is then stored on the returned process. Anything left
    over after the specification has been parsed is an error.

    Raises
    ------
    ValueError
        If unparsed content follows the specification.
    """
    label = None
    head, sep, tail = s.partition("::")
    if sep and head:
        label, s = head, tail

    process, leftover = _partial_delay_from_string(s)
    if leftover.strip() != "":
        raise ValueError(f"Found trailing contents in delay string: \"{leftover}\"")

    process._tag = label
    return process



class NoisyActionWrapper(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
    """
    Adds a bit of gaussian noise to actions prior to applying them to the
    environment. This can be used to induce uncertainty into environments where
    we do not have control of the internal state.

    The wrapped environment's action space must expose `low`, `high` and
    `dtype`.
    """
    def __init__(self, env, noise=0.05):
        """
        Parameters
        ----------
        env : gym.Env
            The environment to wrap.

        noise : float (optional)
            Proportional gaussian standard deviation noise w.r.t. upper and
            lower bound of action space.
        """
        gym.utils.RecordConstructorArgs.__init__(self, noise=noise)
        gym.ActionWrapper.__init__(self, env)

        # Width of the action range per dimension. Where the range is not
        # finite, a width of 1.0 is used so the noise scale stays finite.
        minmax_diff = np.nan_to_num(
            self.env.action_space.high - self.env.action_space.low,
            nan=1.0,
            posinf=1.0,
            neginf=1.0,
        )
        self._action_noise = minmax_diff * noise

    def action(self, action):
        """Adds noise to the action and ensures that it stays within the action space."""
        space = self.action_space
        noisy = scipy.stats.norm.rvs(
            loc=action,
            scale=self._action_noise,
        )
        clipped = np.clip(noisy, space.low, space.high)
        # The draw is double precision; convert back to the dtype of the
        # action space so the result is still a member of that space.
        return np.asarray(clipped, dtype=space.dtype)



