import logging
import numpy as np
import threading
import tree  # pip install dm_tree

from utils.numpy import SMALL_NUMBER
from utils.deprecation import Deprecated
from utils.typing import TensorStructType

logger = logging.getLogger(__name__)


class Filter:
    """Processes input, possibly statefully."""

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with "new state" from other filter."""
        raise NotImplementedError

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        Returns:
            A copy of self.
        """
        raise NotImplementedError

    def sync(self, other: "Filter") -> None:
        """Copies all state from other filter to self."""
        raise NotImplementedError

    def reset_buffer(self) -> None:
        """Creates copy of current state and resets accumulated state"""
        raise NotImplementedError

    def as_serializable(self) -> "Filter":
        raise NotImplementedError


class NoFilter(Filter):
    is_concurrent = True

    def __call__(self, x: TensorStructType, update=True):
        # Process no further if already np.ndarray, dict, or tuple.
        if isinstance(x, (np.ndarray, dict, tuple)):
            return x

        try:
            return np.asarray(x)
        except Exception:
            raise ValueError("Failed to convert to array", x)

    def apply_changes(self, other: "NoFilter", *args, **kwargs) -> None:
        pass

    def copy(self) -> "NoFilter":
        return self

    def sync(self, other: "NoFilter") -> None:
        pass

    def reset_buffer(self) -> None:
        pass

    def as_serializable(self) -> "NoFilter":
        return self


# http://www.johndcook.com/blog/standard_deviation/
class RunningStat:
    def __init__(self, shape=None):
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)

    def copy(self):
        other = RunningStat()
        other._n = self._n
        other._M = np.copy(self._M)
        other._S = np.copy(self._S)
        return other

    def push(self, x):
        x = np.asarray(x)
        # Unvectorized update of the running statistics.
        if x.shape != self._M.shape:
            raise ValueError(
                "Unexpected input shape {}, expected {}, value = {}".format(
                    x.shape, self._M.shape, x
                )
            )
        n1 = self._n
        self._n += 1
        if self._n == 1:
            self._M[...] = x
        else:
            delta = x - self._M
            self._M[...] += delta / self._n
            self._S[...] += delta * delta * n1 / self._n

    def update(self, other):
        n1 = self._n
        n2 = other._n
        n = n1 + n2
        if n == 0:
            # Avoid divide by zero, which creates nans
            return
        delta = self._M - other._M
        delta2 = delta * delta
        M = (n1 * self._M + n2 * other._M) / n
        S = self._S + other._S + delta2 * n1 * n2 / n
        self._n = n
        self._M = M
        self._S = S

    def __repr__(self):
        return "(n={}, mean_mean={}, mean_std={})".format(
            self.n, np.mean(self.mean), np.mean(self.std)
        )

    @property
    def n(self):
        return self._n

    @property
    def mean(self):
        return self._M

    @property
    def var(self):
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

    @property
    def std(self):
        return np.sqrt(self.var)

    @property
    def shape(self):
        return self._M.shape


class MeanStdFilter(Filter):
    """Keeps track of a running mean for seen states"""

    is_concurrent = False

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.shape = shape
        # We don't have a preprocessor, if shape is None (Discrete) or
        # flat_shape is Tuple[np.ndarray] or Dict[str, np.ndarray]
        # (complex inputs).
        flat_shape = tree.flatten(self.shape)
        self.no_preprocessor = shape is None or (
            isinstance(self.shape, (dict, tuple))
            and len(flat_shape) > 0
            and isinstance(flat_shape[0], np.ndarray)
        )
        # If preprocessing (flattning dicts/tuples), make sure shape
        # is an np.ndarray so we don't confuse it with a complex Tuple
        # space's shape structure (which is a Tuple[np.ndarray]).
        if not self.no_preprocessor:
            self.shape = np.array(self.shape)
        self.demean = demean
        self.destd = destd
        self.clip = clip
        # Running stats.
        self.rs = tree.map_structure(lambda s: RunningStat(s), self.shape)

        # In distributed rollouts, each worker sees different states.
        # The buffer is used to keep track of deltas amongst all the
        # observation filters.
        self.buffer = None
        self.reset_buffer()

    def reset_buffer(self) -> None:
        self.buffer = tree.map_structure(lambda s: RunningStat(s), self.shape)

    def apply_changes(
        self, other: "MeanStdFilter", with_buffer: bool = False, *args, **kwargs
    ) -> None:
        """Applies updates from the buffer of another filter.

        Args:
            other: Other filter to apply info from
            with_buffer: Flag for specifying if the buffer should be
                copied from other.

        Examples:
            >>> a = MeanStdFilter(())
            >>> a(1)
            >>> a(2)
            >>> print([a.rs.n, a.rs.mean, a.buffer.n])
            [2, 1.5, 2]
            >>> b = MeanStdFilter(())
            >>> b(10)
            >>> a.apply_changes(b, with_buffer=False)
            >>> print([a.rs.n, a.rs.mean, a.buffer.n])
            [3, 4.333333333333333, 2]
            >>> a.apply_changes(b, with_buffer=True)
            >>> print([a.rs.n, a.rs.mean, a.buffer.n])
            [4, 5.75, 1]
        """
        tree.map_structure(
            lambda rs, other_rs: rs.update(other_rs), self.rs, other.buffer
        )
        if with_buffer:
            self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer)

    def copy(self) -> "MeanStdFilter":
        """Returns a copy of `self`."""
        other = MeanStdFilter(self.shape)
        other.sync(self)
        return other

    def as_serializable(self) -> "MeanStdFilter":
        return self.copy()

    def sync(self, other: "MeanStdFilter") -> None:
        """Syncs all fields together from other filter.

        Examples:
            >>> a = MeanStdFilter(())
            >>> a(1)
            >>> a(2)
            >>> print([a.rs.n, a.rs.mean, a.buffer.n])
            [2, array(1.5), 2]
            >>> b = MeanStdFilter(())
            >>> b(10)
            >>> print([b.rs.n, b.rs.mean, b.buffer.n])
            [1, array(10.0), 1]
            >>> a.sync(b)
            >>> print([a.rs.n, a.rs.mean, a.buffer.n])
            [1, array(10.0), 1]
        """
        self.demean = other.demean
        self.destd = other.destd
        self.clip = other.clip
        self.rs = tree.map_structure(lambda rs: rs.copy(), other.rs)
        self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer)

    def __call__(self, x: TensorStructType, update: bool = True) -> TensorStructType:
        if self.no_preprocessor:
            x = tree.map_structure(lambda x_: np.asarray(x_), x)
        else:
            x = np.asarray(x)

        def _helper(x, rs, buffer, shape):
            # Discrete|MultiDiscrete spaces -> No normalization.
            if shape is None:
                return x

            # Keep dtype as is througout this filter.
            orig_dtype = x.dtype

            if update:
                if len(x.shape) == len(rs.shape) + 1:
                    # The vectorized case.
                    for i in range(x.shape[0]):
                        rs.push(x[i])
                        buffer.push(x[i])
                else:
                    # The unvectorized case.
                    rs.push(x)
                    buffer.push(x)
            if self.demean:
                x = x - rs.mean
            if self.destd:
                x = x / (rs.std + SMALL_NUMBER)
            if self.clip:
                x = np.clip(x, -self.clip, self.clip)
            return x.astype(orig_dtype)

        if self.no_preprocessor:
            return tree.map_structure_up_to(
                x, _helper, x, self.rs, self.buffer, self.shape
            )
        else:
            return _helper(x, self.rs, self.buffer, self.shape)

    def __repr__(self) -> str:
        return "MeanStdFilter({}, {}, {}, {}, {}, {})".format(
            self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer
        )


class ConcurrentMeanStdFilter(MeanStdFilter):
    is_concurrent = True

    def __init__(self, *args, **kwargs):
        super(ConcurrentMeanStdFilter, self).__init__(*args, **kwargs)
        self._lock = threading.RLock()

        def lock_wrap(func):
            def wrapper(*args, **kwargs):
                with self._lock:
                    return func(*args, **kwargs)

            return wrapper

        self.__getattribute__ = lock_wrap(self.__getattribute__)

    def as_serializable(self) -> "MeanStdFilter":
        """Returns non-concurrent version of current class"""
        other = MeanStdFilter(self.shape)
        other.sync(self)
        return other

    def copy(self) -> "ConcurrentMeanStdFilter":
        """Returns a copy of Filter."""
        other = ConcurrentMeanStdFilter(self.shape)
        other.sync(self)
        return other

    def __repr__(self) -> str:
        return "ConcurrentMeanStdFilter({}, {}, {}, {}, {}, {})".format(
            self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer
        )


def get_filter(filter_config, shape):
    # TODO(rliaw): move this into filter manager
    if filter_config == "MeanStdFilter":
        return MeanStdFilter(shape, clip=None)
    elif filter_config == "ConcurrentMeanStdFilter":
        return ConcurrentMeanStdFilter(shape, clip=None)
    elif filter_config == "NoFilter":
        return NoFilter()
    elif callable(filter_config):
        return filter_config(shape)
    else:
        raise Exception("Unknown observation_filter: " + str(filter_config))
