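"""
World sampling utilities: draw sequences of values in [0, 1] from
Beta-based "world" distributions (single Beta or a two-component Beta mixture).
"""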

from __future__ import annotations

import numpy as np

from .core import _resolve_world


def sample_world(N, mu, world="beta", rng=None, conc=6.0):
    """
    Generate a sequence X_1..N of draws in [0, 1].

    Parameters
    ----------
    N : int
        Number of samples (passed straight to NumPy as ``size``).
    mu : float
        Target mean; Beta parameters are built as ``c*mu`` and ``c*(1-mu)``,
        each floored at 1e-8 so NumPy never receives a non-positive shape.
    world : str
        "beta", "beta_mixture", or "random" (50/50 between the two); the value
        is first resolved by ``_resolve_world``.
    rng : numpy.random.Generator, optional
        Randomness source; a fresh ``np.random.default_rng()`` when None.
    conc : float
        Concentration ``c`` of the (first) Beta component.

    Returns
    -------
    numpy.ndarray
        float64 array of N samples.

    Raises
    ------
    ValueError
        If the resolved world is neither "beta" nor "beta_mixture".
    """
    rng = np.random.default_rng() if rng is None else rng
    kind = _resolve_world(world, rng)

    if kind not in ("beta", "beta_mixture"):
        raise ValueError("world must be 'beta' or 'beta_mixture'")

    floor = 1e-8
    if kind == "beta":
        return rng.beta(
            max(floor, conc * mu),
            max(floor, conc * (1.0 - mu)),
            size=N,
        )

    # beta_mixture: every element is taken from one of two Beta components
    # with nominal mean mu but concentrations conc and 2*conc, picked by a
    # fair coin flip per element.
    alpha_lo, beta_lo = max(floor, conc * mu), max(floor, conc * (1.0 - mu))
    alpha_hi = max(floor, 2.0 * conc * mu)
    beta_hi = max(floor, 2.0 * conc * (1.0 - mu))

    use_lo = rng.random(N) < 0.5
    use_hi = ~use_lo
    samples = np.empty(N)
    # The low-concentration component is drawn first, then the high one;
    # this order fixes how the generator stream is consumed.
    samples[use_lo] = rng.beta(alpha_lo, beta_lo, size=use_lo.sum())
    samples[use_hi] = rng.beta(alpha_hi, beta_hi, size=use_hi.sum())
    return samples


def sample_world_batch(B, N, mu, world="beta", rng=None, conc=6.0):
    """
    Vectorized version of sample_world. Returns (B, N) float32.

    Parameters
    ----------
    B, N : int-like
        Batch size and sequence length; both are coerced with ``int()``.
    mu : float
        Target mean; Beta parameters are ``c*mu`` and ``c*(1-mu)``, each
        floored at 1e-8.
    world : str
        "beta" or "beta_mixture"; first passed through ``_resolve_world``
        (which, per ``sample_world``'s contract, maps "random" to one of them).
    rng : numpy.random.Generator, optional
        Randomness source; a fresh ``np.random.default_rng()`` when None.
    conc : float
        Concentration ``c`` of the first Beta component; the mixture's second
        component uses ``2*c``.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (B, N).

    Raises
    ------
    ValueError
        If the resolved world is neither "beta" nor "beta_mixture".
    """
    if rng is None:
        rng = np.random.default_rng()
    B = int(B)
    N = int(N)

    world = _resolve_world(world, rng)

    if world == "beta":
        a = max(1e-8, conc * mu)
        b = max(1e-8, conc * (1.0 - mu))
        return rng.beta(a, b, size=(B, N)).astype(np.float32)

    if world == "beta_mixture":
        a1, b1 = max(1e-8, conc * mu), max(1e-8, conc * (1.0 - mu))
        a2, b2 = max(1e-8, 2.0 * conc * mu), max(1e-8, 2.0 * conc * (1.0 - mu))
        mask = rng.random((B, N)) < 0.5
        # Draw only as many variates as each component actually needs rather
        # than two full (B, N) arrays of which half would be discarded. The
        # fill order (first component, then second) mirrors sample_world.
        # Writing into a float32 buffer casts on assignment, avoiding an
        # extra float64 (B, N) temporary plus a separate astype copy.
        X = np.empty((B, N), dtype=np.float32)
        X[mask] = rng.beta(a1, b1, size=int(mask.sum()))
        X[~mask] = rng.beta(a2, b2, size=int((~mask).sum()))
        return X

    raise ValueError("world must be 'beta' or 'beta_mixture'")


# Public API of this module: the scalar sampler and its batched counterpart.
__all__ = ["sample_world", "sample_world_batch"]
