import numpy as np
from typing import Callable, Sequence, Tuple, Optional

Array = np.ndarray


# ----------------------------
# PPP sampler (needed by generate_time_series)
# ----------------------------

def _build_1d_cdf_on_grid(p: Callable[[Array], Array], grid_x: Array) -> tuple[Array, float]:
    """
    Build a grid CDF for density proportional to p(x) on [grid_x[0], grid_x[-1]].
    Returns (cdf, Z) where Z = ∫ p(x) dx (trapz on grid).
    """
    x = grid_x
    w = np.asarray(p(x), dtype=float)
    w = np.clip(w, 0.0, np.inf)

    Z = float(np.trapz(w, x))
    if not np.isfinite(Z) or Z <= 0.0:
        raise ValueError(f"Integral Z=∫ p(x) dx must be positive and finite. Got Z={Z}")

    pdf = w / Z
    dx = (x[-1] - x[0]) / (len(x) - 1)

    cdf = np.empty_like(x)
    cdf[0] = 0.0
    cdf[1:] = np.cumsum(0.5 * (pdf[:-1] + pdf[1:]) * dx)

    cdf[-1] = 1.0
    cdf = np.clip(cdf, 0.0, 1.0)
    cdf = np.maximum.accumulate(cdf)
    return cdf, Z


def _inv_cdf_sample(u: Array, grid_x: Array, cdf: Array) -> Array:
    """Invert grid CDF via linear interpolation."""
    u = np.asarray(u, dtype=float)
    u = np.clip(u, 0.0, 1.0)
    return np.interp(u, cdf, grid_x)


def sample_ppp_meanfield_mixture2_once(
    ps1: Sequence[Callable[[Array], Array]],
    ps2: Sequence[Callable[[Array], Array]],
    domains: Sequence[Tuple[float, float]],
    w1: float,
    w2: float,
    grid_size: int = 4096,
    rng: Optional[np.random.Generator] = None,
) -> Array:
    """
    Sample ONE realization of a Poisson point process (PPP) on Π_j [a_j, b_j]
    with intensity

        λ(x) = w1 * Π_j p1_j(x_j) + w2 * Π_j p2_j(x_j)

    Each mixture component k is sampled separately:
    - its count is N_k ~ Poisson(Λ_k), with Λ_k = w_k * Π_j ∫ p_k_j;
    - each coordinate of its N_k points is drawn independently by grid
      inverse-CDF sampling from p_k_j.
    The two point sets are then stacked and shuffled.

    Parameters
    ----------
    ps1, ps2 : length-d sequences of callables
        Per-coordinate unnormalized densities of the two components.
    domains : length-d sequence of (a, b)
        Finite bounds with b > a for each coordinate.
    w1, w2 : float
        Finite, nonnegative component weights. A component with weight 0 is
        skipped entirely, so its densities are never evaluated.
    grid_size : int
        Number of grid points per coordinate (must be >= 32).
    rng : numpy.random.Generator, optional
        Source of randomness; a fresh ``np.random.default_rng()`` when None.

    Returns
    -------
    X : ndarray of shape (n, d)
        n is random (Poisson); rows are in shuffled order.

    Raises
    ------
    ValueError
        On mismatched lengths, grid_size < 32, non-finite or negative weights,
        invalid domains, or a used density whose integral is not positive and
        finite.
    """
    if rng is None:
        rng = np.random.default_rng()

    d = len(domains)
    if len(ps1) != d or len(ps2) != d:
        raise ValueError("ps1, ps2, and domains must have the same length d.")
    if grid_size < 32:
        raise ValueError("grid_size should be >= 32.")
    # NaN would slip past the sign checks below and silently yield 0 points.
    if not (np.isfinite(w1) and np.isfinite(w2)):
        raise ValueError(f"w1 and w2 must be finite. Got w1={w1}, w2={w2}")
    if w1 < 0 or w2 < 0:
        raise ValueError("w1 and w2 must be nonnegative.")
    if w1 == 0 and w2 == 0:
        return np.empty((0, d), dtype=float)

    # Shared evaluation grids, one per coordinate.
    grids = []
    for (a, b) in domains:
        if not (np.isfinite(a) and np.isfinite(b) and b > a):
            raise ValueError(f"Each domain must be finite with b>a. Got [{a},{b}]")
        grids.append(np.linspace(a, b, grid_size, dtype=float))

    def _component(ps, w):
        # Return (Lambda, per-dim CDFs) for one component; weight 0 -> (0.0, None).
        if w == 0.0:
            return 0.0, None
        cdfs = []
        Z = np.empty(d, dtype=float)
        for j, (p_j, grid_j) in enumerate(zip(ps, grids)):
            cdf_j, Z[j] = _build_1d_cdf_on_grid(p_j, grid_j)
            cdfs.append(cdf_j)
        # w * Π_j Z_j, accumulated in log space.
        return float(np.exp(np.log(w) + np.sum(np.log(Z)))), cdfs

    def _draw_points(n, cdfs):
        # n i.i.d. points with independent coordinates via inverse-CDF sampling.
        pts = np.empty((n, d), dtype=float)
        if n > 0:
            U = rng.random((n, d))
            for j in range(d):
                pts[:, j] = _inv_cdf_sample(U[:, j], grids[j], cdfs[j])
        return pts

    Lambda1, cdfs1 = _component(ps1, w1)
    Lambda2, cdfs2 = _component(ps2, w2)

    # RNG consumption order: both counts, then coordinates of component 1, then
    # component 2, then the shuffle.
    n1 = int(rng.poisson(Lambda1)) if Lambda1 > 0 else 0
    n2 = int(rng.poisson(Lambda2)) if Lambda2 > 0 else 0

    X1 = _draw_points(n1, cdfs1)
    X2 = _draw_points(n2, cdfs2)
    # Stacking two (0, d) arrays already yields an empty (0, d) float array.
    X = np.vstack([X1, X2])

    # Interleave the two components so row order does not reveal which
    # component a point came from.
    if X.shape[0] > 1:
        rng.shuffle(X, axis=0)

    return X


# ----------------------------
# 2D AR(1) with drift change (needed by generate_time_series)
# ----------------------------

def generate_ar2d_changepoint(
    sigma: float,
    A: np.ndarray,
    mu_before: np.ndarray,
    mu_after: np.ndarray,
    bb: int,
    N_total: int,
    x0: Optional[np.ndarray] = None,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Simulate a 2-D AR(1) process whose drift switches at index ``bb``.

        X[t] = A X[t-1] + eps_t + mu_before   if t < bb
        X[t] = A X[t-1] + eps_t + mu_after    if t >= bb
        eps_t ~ N(0, sigma I_2)               (sigma is the noise VARIANCE)

    The state preceding the first output row is ``x0`` (the origin when None);
    ``x0`` itself is not part of the output.

    Parameters
    ----------
    sigma : float
        Noise variance (>= 0); the per-coordinate std is sqrt(sigma).
    A : array_like of shape (2, 2)
        Transition matrix.
    mu_before, mu_after : array_like with 2 elements
        Drift used for t < bb and t >= bb respectively.
    bb : int
        Change-point index, 0 <= bb <= N_total.
    N_total : int
        Number of time steps (> 0).
    x0 : array_like with 2 elements, optional
        Initial state; defaults to zeros.
    rng : numpy.random.Generator, optional
        Source of the Gaussian noise. When None, the legacy global
        ``np.random`` state is used, so ``np.random.seed`` keeps working.

    Returns
    -------
    X : ndarray of shape (N_total, 2)

    Raises
    ------
    ValueError
        On invalid N_total, sigma, bb, or array shapes.
    """
    if N_total <= 0:
        raise ValueError("N_total must be a positive integer.")
    if sigma < 0:
        raise ValueError("sigma must be >= 0.")
    if not (0 <= bb <= N_total):
        raise ValueError(f"bb must satisfy 0 <= bb <= N_total, got bb={bb}, N_total={N_total}.")

    A = np.asarray(A, dtype=float)
    mu_before = np.asarray(mu_before, dtype=float).reshape(-1)
    mu_after = np.asarray(mu_after, dtype=float).reshape(-1)

    if A.shape != (2, 2):
        raise ValueError(f"A must have shape (2,2), got {A.shape}.")
    if mu_before.shape != (2,) or mu_after.shape != (2,):
        raise ValueError(f"mu_before and mu_after must have shape (2,), got {mu_before.shape}, {mu_after.shape}.")

    if x0 is None:
        x_prev = np.zeros(2, dtype=float)
    else:
        x0 = np.asarray(x0, dtype=float).reshape(-1)
        if x0.shape != (2,):
            raise ValueError(f"x0 must have shape (2,), got {x0.shape}.")
        x_prev = x0.copy()

    # Use the caller's Generator when given, so results are reproducible from
    # it. Otherwise keep the original global-state behaviour.
    draw_normal = rng.standard_normal if rng is not None else np.random.randn

    std = np.sqrt(sigma)
    X = np.empty((N_total, 2), dtype=float)

    for t in range(N_total):
        mu_t = mu_before if t < bb else mu_after
        eps_t = std * draw_normal(2)
        x_t = A @ x_prev + eps_t + mu_t
        X[t] = x_t
        x_prev = x_t

    return X


# ----------------------------
# Public entry point: latent AR(1) weights -> one PPP snapshot per time step
# ----------------------------

class generate_time_series:
    """
    Time series of Poisson point-process snapshots driven by a latent 2-D AR(1).

    At construction, a path ``self.mean`` of shape (N_total, 2) is simulated by
    ``generate_ar2d_changepoint``. It uses drift ``mu_before`` for t < bb and
    ``mu_after`` for t >= bb. ``compute()`` then draws, for each t, one PPP
    realization with intensity

        λ_t(x) = w1_t * Π_j pA_j(x_j) + w2_t * Π_j pB_j(x_j)

    where:
    - (w1_t, w2_t) = self.mean[t], clipped at 0 when ``clip_weights`` is True;
    - (pA, pB) = (ps1, ps2) for t < bb and (ps3, ps4) otherwise.

    Input:  ps1, ps2, ps3, ps4, domains, mu_before, mu_after, bb, N_total
            (plus optional A, sigma, grid_size, clip_weights, x0, rng).
    Output of ``compute()``: a list of length N_total whose t-th entry is an
            (n_t, d) array, d = len(domains).

    Randomness:
    - With a seeded ``rng``, both the latent path and the sampled points are
      reproducible.
    - With ``rng=None``, the latent path uses the global ``np.random`` state
      and each ``compute()`` call uses a fresh ``np.random.default_rng()``.
    """

    def __init__(
        self,
        ps1,
        ps2,
        ps3,
        ps4,
        domains,
        mu_before,
        mu_after,
        bb,
        N_total,
        A: Optional[np.ndarray] = None,
        sigma: float = 1,
        grid_size: int = 4096,
        clip_weights: bool = True,
        x0: Optional[np.ndarray] = None,
        rng: Optional[np.random.Generator] = None,
    ):
        self.ps1 = ps1
        self.ps2 = ps2
        self.ps3 = ps3
        self.ps4 = ps4
        self.domains = domains

        self.N_total = int(N_total)
        bb = int(bb)
        # Out-of-range bb is clamped into [0, N_total] rather than rejected.
        self.bb = max(0, min(bb, self.N_total))

        self.A = np.asarray(A if A is not None else np.array([[0.5, 0.1], [0.1, 0.5]]), dtype=float)
        self.sigma = float(sigma)
        self.grid_size = int(grid_size)
        self.clip_weights = bool(clip_weights)
        self.x0 = x0
        self.rng = rng  # may be None

        # Latent weight path; draws its noise from self.rng when one is given.
        self.mean = generate_ar2d_changepoint(
            sigma=self.sigma,
            A=self.A,
            mu_before=mu_before,
            mu_after=mu_after,
            bb=self.bb,
            N_total=self.N_total,
            x0=self.x0,
            rng=self.rng,
        )

    def compute(self):
        """
        Draw one PPP snapshot per time step.

        Returns a list of N_total arrays; entry t has shape (n_t, d). When
        ``self.rng`` is set, each call advances it, so repeated calls give new
        realizations.
        """
        data = []
        rng = self.rng if self.rng is not None else np.random.default_rng()

        for t in range(self.N_total):
            # Density pair switches at the change point, like the drift does.
            if t < self.bb:
                ps_a, ps_b = self.ps1, self.ps2
            else:
                ps_a, ps_b = self.ps3, self.ps4

            w1 = float(self.mean[t, 0])
            w2 = float(self.mean[t, 1])

            # The AR path can go negative; the sampler rejects negative weights.
            if self.clip_weights:
                w1 = max(0.0, w1)
                w2 = max(0.0, w2)

            X = sample_ppp_meanfield_mixture2_once(
                ps1=ps_a,
                ps2=ps_b,
                domains=self.domains,
                w1=w1,
                w2=w2,
                grid_size=self.grid_size,
                rng=rng,
            )
            data.append(X)

        return data
