import numpy as np
import math
import scipy.stats as stats
import logging
from typing import Tuple, List

_logger = logging.getLogger(__name__)


def quantile_boundaries(n, tail_prob=0.0, scale=1.0):
    """Divide the real line into n equal-probability intervals of a normal.

    The boundaries are the inverse CDF (ppf) of a zero-mean normal
    distribution evaluated at n+1 equally spaced quantiles.

    With the default tail_prob of 0.0 the quantiles span [0, 1], so the first
    boundary is -∞ and the last boundary is ∞. With a positive tail_prob the
    quantiles span [tail_prob, 1 - tail_prob] instead, which makes every
    boundary finite.

    Args:
        n: Number of intervals to divide the range into. Must be a positive
            integer; integral floats such as 4.0 are accepted and converted.
        tail_prob: The probability mass to exclude from each tail. Must
            satisfy 0 <= tail_prob < 0.5.
        scale: Standard deviation of the normal distribution. Must be
            positive.
    Returns:
        A numpy array of n+1 strictly increasing boundaries.
    Raises:
        ValueError: If n, tail_prob or scale is outside its valid range.
    """
    # Validate inputs.
    if n <= 0 or not n == int(n):
        raise ValueError(f"n must be a positive integer ({n=})")
    # The check above lets integral floats through, but np.linspace insists
    # on a true integer count.
    n = int(n)
    if tail_prob and (tail_prob < 0 or tail_prob >= 0.5):
        raise ValueError(f"tail_prob must be between 0 and 0.5 ({tail_prob=})")
    # A non-positive (or NaN) scale makes ppf return NaN; report that as an
    # input error instead of tripping the output invariant below.
    if not np.all(np.asarray(scale) > 0):
        raise ValueError(f"scale must be positive ({scale=})")

    quantiles = np.linspace(tail_prob, 1 - tail_prob, n + 1)
    # ppf is inverse of cdf.
    boundaries = stats.norm.ppf(quantiles, loc=0, scale=scale)

    # Check output invariant.
    assert np.all(
        np.diff(boundaries) > 0
    ), f"Boundaries should be increasing ({boundaries=})"
    return boundaries


class CycleNormal:
    """Deliberately poor stand-in for a normal distribution sampler.

    It holds the interior quantile boundaries of a standard normal and hands
    them out one after another, wrapping around at the end of the list.
    """

    def __init__(self, n_steps, rng=None):
        """Precompute the points that will be cycled through.

        Args:
            n_steps: Number of points in one full cycle.
            rng: Optional generator providing ``permutation``. When given,
                the points are visited in a shuffled order; otherwise in
                increasing order.
        """
        self.n_steps = n_steps
        # n_steps + 1 intervals have n_steps + 2 boundaries; dropping the two
        # outermost (infinite) ones leaves n_steps finite points.
        interior = quantile_boundaries(n_steps + 1)[1:-1]
        self._points = interior if rng is None else rng.permutation(interior)
        assert len(self._points) == n_steps
        self._current_step = 0

    def sample(self, loc=0.0, scale=1.0):
        """Return the next point of the cycle, shifted and scaled.

        Args:
            loc: Offset added to the scaled point.
            scale: Factor applied to the point.
        Returns:
            A tuple (value, [step]) where step is the cycle index after it
            has been advanced, i.e. the index the next call will use.
        """
        drawn = self._points[self._current_step]
        self._current_step = (self._current_step + 1) % self.n_steps
        return loc + drawn * scale, [self._current_step]


def uniform_T(n):
    """Build an n x n transition matrix in which every entry equals 1/n.

    Args:
        n: Number of states.
    Returns:
        np.ndarray: n x n matrix; every state is equally likely from every
        state.
    """
    equal_share = 1 / n
    return np.full((n, n), equal_share)


def primary_plus_uniform_T(n, primary_p=0.9, rng=None):
    """Transition matrix with one primary target and uniform leftovers.

    Each state gets a primary transition with probability primary_p. The
    primary targets are a random permutation of the states. The remaining
    probability is split equally over the other n - 1 states.

    Args:
        n: Number of states (a Python int of at least 2).
        primary_p: Probability of the primary transition, 0 < primary_p < 1.
        rng: Random number generator. A fresh default generator is created
            when None.
    Returns:
        np.ndarray: n x n transition matrix. Each row sums to 1.
    Raises:
        ValueError: If n or primary_p is invalid.
    """
    n_is_valid = bool(n) and n >= 2 and isinstance(n, int)
    if not n_is_valid:
        raise ValueError(f"n must be a positive integer greater than 1 ({n=})")
    if primary_p <= 0 or primary_p >= 1:
        raise ValueError(f"Must have 0 < primary_p < 1 ({primary_p=})")
    if rng is None:
        rng = np.random.default_rng()

    # Row i's primary target is primary_targets[i].
    primary_targets = rng.permutation(n)

    # Start from the uniform leftover share, then overwrite one entry per
    # row with the primary probability.
    T = np.full((n, n), (1 - primary_p) / (n - 1), dtype=float)
    T[np.arange(n), primary_targets] = primary_p

    if np.max(T) > primary_p:
        _logger.warning(
            f"Primary probability is low and not the largest "
            f"transition probability ({primary_p=} vs. {np.max(T)=})"
        )
    assert np.all(T >= 0) and np.all(T <= 1.0), f"{T=}"
    assert np.allclose(np.sum(T, axis=1), 1.0), f"{np.sum(T, axis=1)=}"
    return T


def primary_plus_distance_T(n, primary_p=0.9, rng=None):
    """Transition matrix with a random primary target and distance leftovers.

    Each state gets a primary transition with probability primary_p. The
    primary targets are a random permutation of the states. The remaining
    1 - primary_p is shared among all other states (the from-state itself
    included, unless it is the primary target) in proportion to
    1 / (distance + 1), where distance is the index distance to the
    _from_ state.

    Args:
        n: Number of states (a Python int of at least 2).
        primary_p: Probability of primary transition, 0 < primary_p < 1.
        rng: Random number generator. A fresh default generator is created
            when None.
    Returns:
        np.ndarray: n x n transition matrix. Each row sums to 1.
    Raises:
        ValueError: If n or primary_p is invalid.
    """
    n_is_valid = bool(n) and n >= 2 and isinstance(n, int)
    if not n_is_valid:
        raise ValueError(f"n must be a positive integer greater than 1 ({n=})")
    if primary_p <= 0 or primary_p >= 1:
        raise ValueError(f"Must have 0 < primary_p < 1 ({primary_p=})")
    if rng is None:
        rng = np.random.default_rng()

    # Row i's primary target is next_states[i].
    next_states = rng.permutation(n)
    state_ids = np.arange(n)
    T = np.zeros((n, n))

    for row, target in enumerate(next_states):
        weights = 1 / (np.abs(state_ids - row) + 1)
        # The primary target takes no part in the leftover split.
        weights[target] = 0
        total_base_prob = np.sum(weights)
        assert total_base_prob > 0, f"{total_base_prob=}"
        T[row] = (1 - primary_p) * (weights / total_base_prob)
        T[row, target] = primary_p

    if np.max(T) > primary_p:
        _logger.warning(
            f"Primary probability is low and not the largest "
            f"transition probability ({primary_p=} vs. {np.max(T)=})"
        )
    assert np.all(T >= 0) and np.all(T <= 1.0), f"{T=}"
    assert np.allclose(np.sum(T, axis=1), 1.0), f"{np.sum(T, axis=1)=}"
    return T


def primary_plus_distance_Tv2(n, primary_p=0.9, rng=None):
    """Transition matrix with a random primary target on top of a distance base.

    Every row first receives a base distribution that spreads 1 - primary_p
    over all states in proportion to 1 / (distance + 1), where distance is
    the index distance to the _from_ state. primary_p is then added to the
    row's primary target, so the target keeps its base share as well. The
    primary targets are a random permutation of the states.

    A warning is logged for each row whose primary target does not end up
    with the largest probability of that row.

    Args:
        n: Number of states (a Python int of at least 2).
        primary_p: Probability added for the primary transition,
            0 < primary_p < 1.
        rng: Random number generator. A fresh default generator is created
            when None.
    Returns:
        np.ndarray: n x n transition matrix. Each row sums to 1.
    Raises:
        ValueError: If n or primary_p is invalid.
    """
    n_is_valid = bool(n) and n >= 2 and isinstance(n, int)
    if not n_is_valid:
        raise ValueError(f"n must be a positive integer greater than 1 ({n=})")
    if primary_p <= 0 or primary_p >= 1:
        raise ValueError(f"Must have 0 < primary_p < 1 ({primary_p=})")
    if rng is None:
        rng = np.random.default_rng()

    # Row i's primary target is next_states[i].
    next_states = rng.permutation(n)
    state_ids = np.arange(n)
    T = np.zeros((n, n))

    for i, target in enumerate(next_states):
        weights = 1 / (np.abs(state_ids - i) + 1)
        total_base_prob = np.sum(weights)
        assert total_base_prob > 0, f"{total_base_prob=}"
        T[i] = (1 - primary_p) * (weights / total_base_prob)
        T[i, target] += primary_p
        if np.max(T[i]) > T[i, target]:
            _logger.warning(
                f"Primary probability is low and not the largest "
                f"transition probability ({primary_p=} vs. {np.max(T[i])=})"
            )

    assert np.all(T >= 0) and np.all(T <= 1.0), f"{T=}"
    assert np.allclose(np.sum(T, axis=1), 1.0), f"{np.sum(T, axis=1)=}"
    return T


def pos_to_state(x, boundaries):
    """Return the index of the interval that position x falls into.

    Positions left of the first boundary map to state 0 and positions at or
    right of the last boundary map to the last state.

    Args:
        x: The position.
        boundaries: Increasing interval boundaries; consecutive pairs
            delimit the states.
    Returns:
        The state index, between 0 and len(boundaries) - 2 inclusive.
    """
    last_state = len(boundaries) - 2
    s = np.digitize(x, boundaries) - 1
    # np.digitize reports -1 / last_state + 1 (after the shift above) for
    # positions outside the boundaries; pull those back into range.
    if s > last_state:
        s = last_state
    if s <= 0:
        s = 0
    assert 0 <= s < len(boundaries) - 1, (s, len(boundaries))
    return s


def sample_next_state(x, boundaries, T, rng=None):
    """Draw the next state using the transition row of x's current state.

    Args:
        x: Current position.
        boundaries: Boundaries of the intervals.
        T: Transition matrix; row i holds the probabilities of moving from
            state i to each state.
        rng: Random number generator. A fresh default generator is created
            when None.
    Returns:
        The index of the sampled next state.
    Raises:
        ValueError: If T does not have one row per interval.
    """
    generator = np.random.default_rng() if rng is None else rng
    n_intervals = len(boundaries) - 1
    if len(T) != n_intervals:
        raise ValueError(
            f"Incompatible shapes ({T.shape=}, {len(boundaries)=})"
        )
    row = T[pos_to_state(x, boundaries)]
    return generator.choice(len(row), p=row)


def transition_prob(x, inter_state, xx, boundaries, T, inner_pdf):
    """Compute the log forward and log reverse proposal probabilities.

    The proposal being scored works as follows: x's state is looked up, an
    intermediate state is drawn from that state's row of T, x is moved to
    the same relative position inside the intermediate state's interval, and
    xx is drawn around that intermediate position by the inner sampler.

    Forward: the probability of the given intermediate state times the
    inner density of reaching xx from the intermediate position.

    Reverse: starting from xx, every state k can be chosen as intermediate
    state with probability T[state(xx), k]; from xx's image in interval k,
    the inner density of landing on x is evaluated. The reverse probability
    is the sum of these products over all k.

    Args:
        x: Current position.
        inter_state: The intermediate state (sampled through T).
        xx: The next position.
        boundaries: Boundaries of the intervals (must be finite).
        T: Transition matrix.
        inner_pdf: Density of the inner sampler, called with the offset
            between the drawn point and the point it was drawn around. It
            must accept a numpy array of offsets.
    Returns:
        tuple: log of the forward probability, log of the reverse
        probability.
    """
    # Forward direction: x -> intermediate position -> xx.
    from_state = pos_to_state(x, boundaries)
    x_fraction = to_rel_pos(x, boundaries)
    inter_x = from_rel_pos_single(x_fraction, boundaries, inter_state)
    hop_p = T[from_state, inter_state]
    assert hop_p > 0, f"Can't transition to a state with 0 prob."
    forward_sample_p = inner_pdf(xx - inter_x)
    assert forward_sample_p > 0, f"{forward_sample_p=}, {xx=}, {inter_x=}"
    log_forward_p = math.log(hop_p) + math.log(forward_sample_p)

    # Reverse direction: xx -> its image in every interval -> x.
    landing_state = pos_to_state(xx, boundaries)
    xx_fraction = to_rel_pos(xx, boundaries)
    reverse_ps = from_rel_pos(xx_fraction, boundaries)
    assert len(reverse_ps) == T.shape[0], f"{reverse_ps.shape=}, {T.shape[0]=}"
    # One density per candidate intermediate state, weighted by the
    # probability of choosing that state from xx's state.
    densities_back_to_x = inner_pdf(np.array([x]) - reverse_ps)
    log_reverse_p = np.log(T[landing_state].T @ densities_back_to_x)
    return log_forward_p, log_reverse_p


def to_rel_pos(x, boundaries):
    """Express x as a fraction of the way through its interval.

    The fraction is measured from the interval's left edge when the
    interval's right edge is positive, and from its right edge otherwise.
    Positions outside the boundaries are clipped to the nearest edge
    interval first.

    Args:
        x: The position.
        boundaries: The boundaries of the intervals (must be finite).
    Returns:
        The relative position, (a number between 0 and 1 inclusive).
    Raises:
        ValueError: If any boundary is not finite.
    """
    if not np.all(np.isfinite(boundaries)):
        raise ValueError(f"Only supports finite boundaries ({boundaries=})")
    state = pos_to_state(x, boundaries)
    left, right = boundaries[state], boundaries[state + 1]
    clipped = max(left, min(right, x))
    edge_states = {0, len(boundaries) - 2}
    assert (
        clipped == x
    ) or state in edge_states, f"Only the edge states are clipped. ({state=})"
    width = abs(right - left)
    # Measure from whichever edge the interval's direction starts at.
    p = clipped - left if right > 0 else right - clipped
    assert p >= 0 and width > 0, f"{p=}, {width=}"
    return p / width


# Precomputed index vector [0, 1, ..., 999], built once at import time so that
# a prefix of it can be sliced off instead of calling np.arange again. A slice
# of it never has more than 1000 entries.
_range_1k_cache = np.arange(1000)


def from_rel_pos(rel_pos, boundaries):
    """Map a relative position to an absolute position in each interval.

    Direction convention: an interval whose right edge is positive (this
    includes an interval that contains 0) is traversed left to right, so
    rel_pos 0 is its left edge. Any other interval is traversed right to
    left, so rel_pos 0 is its right edge. For boundaries symmetric around 0
    this means intervals start at the edge closest to 0.

    Works for any number of intervals.

    Args:
        rel_pos: The relative position (a scalar, normally in [0, 1]).
        boundaries: The boundaries of the intervals.
    Returns:
        A numpy array with one absolute position per interval.
    """
    boundaries = np.asarray(boundaries)
    left = boundaries[:-1]
    right = boundaries[1:]
    # Written as convex combinations so that rel_pos 0 and 1 land exactly on
    # a boundary:
    # left to right: pos = left * (1 - rel_pos) + right * rel_pos
    # right to left: pos = right * (1 - rel_pos) + left * rel_pos
    remainder = 1 - rel_pos
    left_to_right = left * remainder + right * rel_pos
    right_to_left = right * remainder + left * rel_pos
    # Strictly positive right edge selects left-to-right; an interval that
    # ends exactly at 0 therefore runs right to left.
    return np.where(right > 0, left_to_right, right_to_left)


def from_rel_pos_single(rel_pos, boundaries, target_state):
    """Map a relative position to an absolute position in one interval.

    The relative position is measured from the interval's left edge when the
    interval's right edge is positive, and from its right edge otherwise.

    Args:
        rel_pos: The relative position.
        boundaries: The boundaries of the intervals (must be finite).
        target_state: The target interval.
    Returns:
        The actual position.
    Raises:
        ValueError: If any boundary is not finite.
    """
    if not np.all(np.isfinite(boundaries)):
        raise ValueError(f"Only supports finite boundaries ({boundaries=})")
    left, right = boundaries[target_state], boundaries[target_state + 1]
    offset = rel_pos * (right - left)
    return left + offset if right > 0 else right - offset


class MetropolisNormalQ:
    """Metropolis-style sampler with a normal random-walk proposal.

    Proposals are drawn from a normal distribution centred on the current
    point and accepted with probability min(1, target(xx) / target(x)).

    NOTE(review): _sample keeps proposing until a proposal is accepted, so a
    rejection never repeats the current point in the output. A textbook
    Metropolis chain emits the current point again on rejection; confirm
    that this difference is intended.
    """

    def __init__(
        self,
        target_pdf,
        initial_x=None,
        n_warmup=None,
        rng=None,
        proposal_sigma=None,
    ):
        """Initialize the sampler.

        Args:
            target_pdf: The target probability density function (doesn't have
                to be normalized).
            initial_x: Initial position. Drawn from a standard normal when
                None.
            n_warmup: Number of warmup steps run (and discarded) on the first
                call to sample. None means no warmup.
            rng: A numpy random number generator. A fresh default generator
                is created when None.
            proposal_sigma: Standard deviation of the normal proposal.
                Defaults to 1.0 when None.
        """
        self.n_warmup = 0 if n_warmup is None else n_warmup
        self._warmup_done = self.n_warmup == 0
        self.rng = rng or np.random.default_rng()
        self._x = self.rng.normal() if initial_x is None else initial_x
        # rng.normal cannot take scale=None, so resolve the default here.
        self.proposal_sigma = 1.0 if proposal_sigma is None else proposal_sigma
        self.target_pdf = target_pdf

    def _sample(self):
        """Propose until one proposal is accepted, then move there.

        Returns:
            The accepted position (also stored as the current position).
        Raises:
            ValueError: From math.log, if target_pdf returns a negative
                value, or a non-positive value at the current position.
        """
        step = 0
        _ideally_done_before = int(5e3)
        while True:
            xx = self.rng.normal(loc=self._x, scale=self.proposal_sigma)
            proposed_density = self.target_pdf(xx)
            # A proposal with zero target density has acceptance probability
            # zero: reject it instead of letting math.log(0) raise.
            if proposed_density != 0:
                # math.log raises an exception for negative inputs, np.log
                # doesn't.
                prob_ratio = math.log(proposed_density) - math.log(
                    self.target_pdf(self._x)
                )
                acceptance_ratio = np.exp(prob_ratio)
                # Shortcut for acceptance ratio > 1.
                if (
                    acceptance_ratio >= 1
                    or self.rng.random() < acceptance_ratio
                ):
                    # Accept the proposal.
                    break
            step += 1
            if (step + 1) % _ideally_done_before == 0:
                _logger.warning(
                    f"Still no accept (step={step}). Current x={self._x}"
                )
        _logger.debug(f"Accepted proposal after {step} steps")
        self._x = xx
        return xx

    def sample(self, size=1):
        """Draw consecutive samples from the chain.

        The warmup steps are run once, before the first samples are drawn.

        Args:
            size: Number of samples to draw.
        Returns:
            A 1D numpy float array of length size.
        """
        if not self._warmup_done:
            _logger.debug(f"[start] warmup for {self.n_warmup} steps")
            for _ in range(self.n_warmup):
                self._sample()
            self._warmup_done = True
            _logger.debug("[end] warmup")
        res = np.empty(size, dtype=float)
        for i in range(size):
            res[i] = self._sample()
        return res


class GridMetropolis:
    """Metropolis sampler imbued with a transition grid.

    Currently only supports 1D (using intervals).

    A proposal first jumps to an intermediate interval chosen through the
    transition matrix (keeping the relative position inside the interval)
    and then takes a step from there with the inner sampler.

    NOTE(review): _sample keeps proposing until a proposal is accepted, so a
    rejection never repeats the current point in the output. A textbook
    Metropolis-Hastings chain emits the current point again on rejection;
    confirm that this difference is intended.
    """

    def __init__(
        self,
        target_pdf,
        boundaries,
        T,
        initial_x=None,
        n_warmup=None,
        rng=None,
        inner_sample_fn=None,
        inner_pdf=None,
    ):
        """Initialize the sampler for len(boundaries) - 1 states.

        Args:
            target_pdf: The target probability density function (doesn't have
                to be normalized).
            boundaries: n+1 interval boundaries (array-like, converted to a
                numpy array).
            T: n x n transition matrix (array-like, converted to a numpy
                array).
            initial_x: Initial position. Drawn from a standard normal when
                None.
            n_warmup: Number of warmup steps run (and discarded) on the first
                call to sample. None means no warmup.
            rng: A numpy random number generator. A fresh default generator
                is created when None.
            inner_sample_fn: Callable taking ``loc=`` and returning either a
                sample or a tuple (sample, list of states). Defaults to the
                generator's ``normal``.
            inner_pdf: Density matching inner_sample_fn, called with the
                offset from loc. Defaults to the standard normal pdf.
        Raises:
            ValueError: If T is not n x n for n = len(boundaries) - 1, or if
                only one of inner_sample_fn and inner_pdf is given.
        """
        # Accept plain sequences as well as arrays; arrays pass through
        # unchanged.
        boundaries = np.asarray(boundaries)
        T = np.asarray(T)
        if T.ndim != 2 or not (
            T.shape[0] == T.shape[1] == len(boundaries) - 1
        ):
            raise ValueError(
                f"Incompatible shapes ({T.shape=} and {len(boundaries)=})"
            )
        self.boundaries = boundaries
        self.T = T
        self.n_states = self.T.shape[0]
        self.n_warmup = 0 if n_warmup is None else n_warmup
        self._warmup_done = self.n_warmup == 0
        self.rng = rng or np.random.default_rng()
        self._x = self.rng.normal() if initial_x is None else initial_x
        if inner_sample_fn is None and inner_pdf is None:
            self.inner_sample_fn = self.rng.normal
            self.inner_pdf = stats.norm.pdf
        elif inner_sample_fn is not None and inner_pdf is not None:
            self.inner_sample_fn = inner_sample_fn
            self.inner_pdf = inner_pdf
        else:
            raise ValueError(
                "Either both or neither of inner_sample_fn and inner_pdf "
                "should be provided."
            )
        self.target_pdf = target_pdf

    def _propose_next(self) -> Tuple[float, float, float, List]:
        """Propose the next sample point.

        Returns:
           - The proposed point.
           - The log of the forward proposal probability (current -> new).
           - The log of the reverse proposal probability (new -> current).
           - The states involved: the intermediate state followed by any
             states reported by the inner sampler.
        """
        inter_state = sample_next_state(
            self._x, self.boundaries, self.T, self.rng
        )
        rel_x = to_rel_pos(self._x, self.boundaries)
        inter_pos = from_rel_pos_single(rel_x, self.boundaries, inter_state)
        # Inner state below is not used in the algorithm, but it's tracked
        # to log the inner sampler's state.
        inner_sample = self.inner_sample_fn(loc=inter_pos)
        # For tracking, we may get a tuple of (x, state) or just x if the
        # inner sampler doesn't support state reporting.
        if isinstance(inner_sample, tuple):
            xx = inner_sample[0]
            inner_state = inner_sample[1]
        else:
            xx = inner_sample
            inner_state = []
        forward_p, reverse_p = transition_prob(
            self._x,
            inter_state,
            xx,
            self.boundaries,
            self.T,
            self.inner_pdf,
        )
        # For tracking.
        all_states = [inter_state]
        all_states.extend(inner_state)
        _logger.debug(f"state={all_states}")
        return xx, forward_p, reverse_p, all_states

    def _sample(self):
        """Propose until one proposal is accepted, then move there.

        Returns:
            A tuple (accepted position, states of the accepted proposal).
        Raises:
            ValueError: From math.log, if target_pdf returns a negative
                value, or a non-positive value at the current position.
        """
        step = 0
        _ideally_done_before = int(5e3)
        while True:
            xx, forward_p, reverse_p, states = self._propose_next()
            proposed_density = self.target_pdf(xx)
            # A proposal with zero target density has acceptance probability
            # zero: reject it instead of letting math.log(0) raise.
            if proposed_density != 0:
                # forward_p and reverse_p are logs, so this is the log of the
                # proposal correction reverse / forward.
                transition_ratio = reverse_p - forward_p
                # math.log raises an exception for negative inputs, np.log
                # doesn't.
                prob_ratio = math.log(proposed_density) - math.log(
                    self.target_pdf(self._x)
                )
                acceptance_ratio = np.exp(prob_ratio + transition_ratio)
                # Shortcut for acceptance ratio > 1.
                if (
                    acceptance_ratio >= 1
                    or self.rng.random() < acceptance_ratio
                ):
                    # Accept the proposal.
                    break
            step += 1
            if (step + 1) % _ideally_done_before == 0:
                _logger.warning(
                    f"Still no accept (step={step}). Current x={self._x}"
                )

        _logger.debug(f"Accepted proposal after {step} steps")
        self._x = xx
        return xx, states

    def sample(self, loc=0.0, scale=1.0, size=None):
        """Draw samples and apply an affine transform loc + x * scale.

        One chain step is taken per output element. loc and scale are
        broadcast against each other; with size given, size such blocks are
        produced and stacked along a new leading axis.

        Args:
            loc: Offset(s) added to the scaled samples.
            scale: Factor(s) applied to the raw samples.
            size: Optional integer number of blocks to draw.
        Returns:
            ss: numpy array of samples.
            states: List of states for each sample. Each entry of states is
                a list of the states for the corresponding sample. States
                are ordered from outermost to innermost. So the last entry
                in the list is the innermost sampler's state. When exactly
                one sample was drawn, its state list is returned directly.
        Raises:
            ValueError: If size is not an integer value, or loc and scale
                cannot be broadcast together.
        """
        # Broadcast both ways so that an array scale with a scalar loc works
        # too.
        loc, scale = np.broadcast_arrays(np.asarray(loc), np.asarray(scale))
        if size is None:
            out_shape = loc.shape
            size = 1
        else:
            if int(size) != size:
                raise ValueError(f"Supports only integer size ({type(size)=})")
            size = int(size)
            out_shape = [size] + list(loc.shape)
        if not self._warmup_done:
            _logger.debug(f"[start] warmup for {self.n_warmup} steps")
            for _ in range(self.n_warmup):
                self._sample()
            self._warmup_done = True
            _logger.debug("[end] warmup")
        _log_interval = 1000
        loc_flat = loc.flatten()
        scale_flat = scale.flatten()
        n_elements = int(np.prod(out_shape))
        res = np.empty(n_elements, dtype=float)
        row_len = len(loc_flat)
        all_states = []
        for i in range(size):
            for j in range(row_len):
                idx = i * row_len + j
                x, states = self._sample()
                res[idx] = loc_flat[j] + x * scale_flat[j]
                all_states.append(states)
                if (idx + 1) % _log_interval == 0:
                    _logger.info(f"sample {i+1}/{size}")
        ss = res.reshape(out_shape)
        # This next line allows sample to be used as a top-level function
        # that can make many samples and as a call for a single sample that
        # we will concat states. Not very clean.
        if len(all_states) == 1:
            all_states = all_states[0]
        return ss, all_states
