import numpy as np
from typing import Callable, Dict, Optional, Tuple, Union

Array = np.ndarray  # shorthand alias for NumPy arrays, used in this module's type hints

def _ensure_rng(rng: Optional[Union[int, np.random.Generator]]) -> np.random.Generator:
    """
    Normalize RNG input to a ``np.random.Generator``.

    Parameters
    ----------
    rng : np.random.Generator, int, or None
        - a Generator is returned unchanged (same object);
        - an int (or any other seed accepted by ``np.random.default_rng``)
          seeds a new Generator;
        - None creates a new Generator seeded from OS entropy.

    Returns
    -------
    np.random.Generator
    """
    # `default_rng` already returns an existing Generator unaltered, so no
    # explicit isinstance branch is needed.
    return np.random.default_rng(rng)

def _prep_grid(domain: Array, cells_per_dim: int):
    """
    Build a regular axis-aligned grid over a d-dimensional hyper-rectangle.

    Parameters
    ----------
    domain : (d, 2) array
        Bounds per dimension [[a1, b1], ..., [ad, bd]].
    cells_per_dim : int
        Number of equal-width cells per dimension.

    Returns
    -------
    edges : list[np.ndarray]
        Per-dimension bin edges (length cells_per_dim+1).
    widths : (d,) np.ndarray
        Cell width for each dimension (constant per dim).
    lo : (num_cells, d) np.ndarray
        Lower-left corner of each cell.
    cell_vol : float
        Cell volume (product of widths).
    idx : (num_cells, d) np.ndarray
        Grid indices for each cell.
    num_cells : int
        Total number of grid cells = cells_per_dim ** d.
    d : int
        Ambient dimension.
    """
    d = domain.shape[0]
    edges = [np.linspace(domain[j,0], domain[j,1], cells_per_dim+1) for j in range(d)]
    widths = np.array([e[1]-e[0] for e in edges])               # (d,) constant per dim
    num_cells = cells_per_dim ** d
    # cell index grid: (num_cells, d), each col in [0, cells_per_dim-1]
    idx = np.stack(np.unravel_index(np.arange(num_cells), (cells_per_dim,)*d), axis=1)
    lo = np.column_stack([edges[j][idx[:, j]] for j in range(d)])  # (num_cells, d)
    cell_vol = float(np.prod(widths))
    return edges, widths, lo, cell_vol, idx, num_cells, d

def _ppp_result(X: Array, proposals: int, return_meta: bool):
    """Return `X`, or `(X, meta)` with proposal/acceptance diagnostics if requested."""
    if not return_meta:
        return X
    accepted = int(X.shape[0])
    rate = float(accepted / proposals) if proposals else 0.0
    return X, {"proposals": int(proposals), "accepted": accepted, "acceptance_rate": rate}


def _uniform_points_in_cells(cell_lo: Array, widths: Array, counts: Array,
                             rng: np.random.Generator) -> Tuple[Array, Array]:
    """Draw counts[i] uniform points inside cell i; return (points, owning cell index)."""
    cell_ids = np.repeat(np.arange(len(counts)), counts)          # (total,)
    offsets = rng.random((cell_ids.size, widths.size))            # (total, d) in [0, 1)
    return cell_lo[cell_ids] + offsets * widths, cell_ids


def sample_inhomogeneous_ppp_adaptive_fast(
    d: int,
    intensity_fn: Callable[[Array], Array],
    domain: Optional[Array] = None,
    cells_per_dim: int = 8,
    per_cell_probe: int = 8,
    safety_factor: float = 1.2,
    rng: Optional[Union[int, np.random.Generator]] = None,
    return_meta: bool = False,
    approximate: bool = False,   # if True, piecewise-constant approx (no thinning)
) -> Union[Array, Tuple[Array, Dict]]:
    """
    Vectorized adaptive-grid sampler for inhomogeneous Poisson point processes.

    Two modes
    ---------
    • approximate=True:
        piecewise-constant approximation per cell,
        draw Poisson counts with mean (cell_mean_intensity × cell_volume),
        then sample uniformly within cells (no thinning; fastest).
    • approximate=False:
        thinning using per-cell maxima; propose at (empirical max × safety_factor),
        accept with probability λ(x)/λ_max(cell). The result is exact only if
        this empirical bound really dominates λ on each cell; where it does not,
        the acceptance probability is capped at 1, which under-samples peaks.

    Parameters
    ----------
    d : int
        Dimension of the domain.
    intensity_fn : Callable[[Array], Array]
        Vectorized function mapping X∈R^{n×d} → n finite, nonnegative intensities λ(X).
    domain : (d, 2) array, optional
        Per-dimension bounds; defaults to the unit hypercube [0,1]^d.
    cells_per_dim : int, default=8
        Grid resolution per dimension (must be >= 1).
    per_cell_probe : int, default=8
        Random probe points per cell to estimate mean/max intensity (must be >= 1).
    safety_factor : float, default=1.2
        Buffer applied to empirical per-cell maxima (thinning mode; must be > 0).
    rng : np.random.Generator, int, or None
        Random generator or seed (for reproducibility).
    return_meta : bool, default=False
        If True, also return proposal/acceptance diagnostics.
    approximate : bool, default=False
        If True, use the piecewise-constant approximation.

    Returns
    -------
    X : (n_points, d) np.ndarray
        PPP locations for a single realization. Returned alone when
        return_meta=False.
    (X, meta) : tuple
        When return_meta=True, `meta` is
        {'proposals', 'accepted', 'acceptance_rate'}.

    Raises
    ------
    ValueError
        On a malformed domain, invalid grid/probe/safety settings, or when
        `intensity_fn` returns negative, non-finite, or wrongly-sized output.

    Notes
    -----
    - Typical experiment settings use cells_per_dim=20, per_cell_probe=20, safety_factor=1.2.
    """
    if cells_per_dim < 1:
        raise ValueError("cells_per_dim must be a positive integer")
    if per_cell_probe < 1:
        raise ValueError("per_cell_probe must be a positive integer")

    rng = _ensure_rng(rng)
    domain = np.asarray(domain if domain is not None
                        else np.column_stack([np.zeros(d), np.ones(d)]), float)
    if domain.shape != (d, 2):
        raise ValueError("domain must have shape (d, 2)")

    _, widths, lo, cell_vol, _, num_cells, _ = _prep_grid(domain, cells_per_dim)

    # --------- probe all cells at once to estimate per-cell max/mean ---------
    U = rng.random((num_cells, per_cell_probe, d))
    probes = lo[:, None, :] + U * widths[None, None, :]          # (num_cells, probe, d)
    lam_probe = np.asarray(intensity_fn(probes.reshape(-1, d))).reshape(num_cells, per_cell_probe)
    if np.any(lam_probe < 0):
        raise ValueError("Intensity function returned negative values.")
    # NaN would otherwise be silently masked out in thinning mode (NaN > 0 is False).
    if not np.all(np.isfinite(lam_probe)):
        raise ValueError("Intensity function returned non-finite values.")

    if approximate:
        # ---- piecewise-constant approximation (no thinning) ----
        n_prop = rng.poisson(lam_probe.mean(axis=1) * cell_vol)   # (num_cells,)
        X, _ = _uniform_points_in_cells(lo, widths, n_prop, rng)
        # Every proposal is kept, so proposals == accepted.
        return _ppp_result(X, int(n_prop.sum()), return_meta)

    # ---- exact thinning against per-cell dominating intensities ----
    if safety_factor <= 0:
        raise ValueError("safety_factor must be positive")
    lam_cell_max = safety_factor * lam_probe.max(axis=1)          # (num_cells,)
    mask_pos = lam_cell_max > 0                                   # skip cells with zero bound
    if not np.any(mask_pos):
        return _ppp_result(np.empty((0, d)), 0, return_meta)

    lam_dom_cells = lam_cell_max[mask_pos]                        # (M,)
    n_prop = rng.poisson(lam_dom_cells * cell_vol)                # homogeneous PPP counts
    total = int(n_prop.sum())
    if total == 0:
        # Avoid calling intensity_fn on an empty batch.
        return _ppp_result(np.empty((0, d)), 0, return_meta)

    X_prop, cell_ids = _uniform_points_in_cells(lo[mask_pos], widths, n_prop, rng)

    lam_vals = np.asarray(intensity_fn(X_prop)).reshape(-1)
    if lam_vals.size != total:
        raise ValueError("Intensity function must return one value per input point.")
    if np.any(lam_vals < 0):
        raise ValueError("Intensity function returned negative values.")

    # Clip guards against λ(x) exceeding the empirical bound (probability capped at 1).
    p = np.clip(lam_vals / lam_dom_cells[cell_ids], 0.0, 1.0)
    keep = rng.random(total) < p
    return _ppp_result(X_prop[keep], total, return_meta)





class poisson_stream:
    """
    Minimal wrapper to generate a sequence of PPP realizations (one per time step).

    Typical usage
    -------------
    • Synthetic streams: choose dimension d, pass pre-/post-change intensities,
      and concatenate realizations to form a stream with a known change point.
    • Real-data proxy: use d=2 with coordinates mapped to [0,1]^2 upstream,
      and generate one realization per day (or time unit).

    Notes
    -----
    • Grid defaults in `.generate` mirror the settings used in experiments
      (cells_per_dim=20, per_cell_probe=20, safety_factor=1.2).
    • The caller controls when to switch the intensity function `lam`.
    """

    def __init__(self, dim, domain=None):
        """
        Parameters
        ----------
        dim : int
            Ambient dimension of the point process.
        domain : (dim, 2) array-like, optional
            Per-dimension bounds; defaults to the unit hypercube [0,1]^dim.

        Raises
        ------
        ValueError
            If `domain` is given and does not have shape (dim, 2).
        """
        self.dim = dim
        if domain is None:
            self.domain = np.array([[0.0, 1.0] for __ in range(self.dim)])
        else:
            self.domain = np.asarray(domain, dtype=float)
            if self.domain.shape != (dim, 2):
                raise ValueError("domain must have shape (dim, 2)")

    def generate(self, N, lam, data=None, *, rng=None, approximate=False,
                 cells_per_dim=20, per_cell_probe=20, safety_factor=1.2):
        """
        Append N realizations to `data` using intensity function `lam`.

        Parameters
        ----------
        N : int
            Number of time steps to simulate.
        lam : Callable[[Array], Array]
            Vectorized intensity λ(x) ≥ 0.
        data : list, optional
            Target list; each append is an array of shape (n_t, dim) for one
            step. A new list is created when omitted.
        rng : np.random.Generator, int, or None, keyword-only
            Random generator or seed. A single generator is shared across all
            N steps, so a fixed seed makes the whole batch reproducible.
        approximate : bool, keyword-only, default=False
            Use the faster piecewise-constant approximation instead of thinning.
        cells_per_dim, per_cell_probe, safety_factor : keyword-only
            Grid settings forwarded to the sampler.

        Returns
        -------
        data : list
            The same list (or the new one) with N additional realizations.
        """
        if data is None:
            data = []
        # Normalize once; a Generator passes through unchanged.
        rng = np.random.default_rng(rng)

        for _ in range(N):
            pts = sample_inhomogeneous_ppp_adaptive_fast(
                self.dim, lam, domain=self.domain,
                cells_per_dim=cells_per_dim, per_cell_probe=per_cell_probe,
                safety_factor=safety_factor, rng=rng, approximate=approximate,
            )
            # The sampler returns a fresh array each call, so no defensive copy.
            data.append(pts)

        return data
