import numpy as np
import torch
from torch import Generator


def _sample_exponential(
    scale: float,
    generator: Generator,
    shape=(),
    dtype=None,
    device=None,
):
    """
    Draw samples from an exponential distribution with mean ``scale``.

    Uses inverse-transform sampling: if ``U ~ Uniform[0, 1)`` then
    ``-scale * log(1 - U)`` is exponentially distributed with mean ``scale``.

    Args:
        scale: Mean of the distribution (the inverse of the rate).
        generator: torch random generator used for the uniform draws.
        shape: Output size (an int or a tuple of ints); ``()`` gives a 0-d tensor.
        dtype: Floating dtype of the result; ``None`` means torch's current
            default dtype (resolved at call time, not at import time).
        device: Device of the result; ``None`` means torch's current default
            device (resolved at call time, not at import time).

    Returns:
        Tensor of the requested shape whose entries are non-negative and,
        for a finite ``scale``, always finite.
    """
    u = torch.rand(shape, generator=generator, dtype=dtype, device=device)
    # torch.rand draws from [0, 1), so log(u) would be -inf (an infinite
    # waiting time) when u == 0. log1p(-u) == log(1 - u) works on (0, 1]
    # instead and therefore never diverges.
    return -scale * torch.log1p(-u)


def sample_periods(
    time_on: float,
    time_off: float,
    time_bleach: float,
    gen: Generator,
    max_duration: float = -1,
    dtype=None,
    device=None,
):
    """
    Simulate the ON periods of a single ON/OFF/BLEACHED chain until it bleaches.

    The chain starts ON at time 0. While ON, an OFF transition and a BLEACH
    transition compete: both waiting times are drawn from exponential
    distributions (means ``time_off`` and ``time_bleach``) and the earlier one
    wins (a tie counts as BLEACH). After an OFF transition the chain waits an
    exponential time with mean ``time_on`` and then starts a new ON period.
    The simulation stops at the first BLEACH transition.

    Args:
        time_on: Mean waiting time from an OFF transition back to ON.
        time_off: Mean waiting time of an ON -> OFF transition.
        time_bleach: Mean waiting time of an ON -> BLEACHED transition.
        gen: torch random generator used for all draws.
        max_duration: If ``>= 0``, ON periods are clipped to end at
            ``max_duration`` and the simulation stops once that time is
            reached. A negative value (the default) disables the cap.
        dtype: Floating dtype of the result; ``None`` means torch's current
            default dtype.
        device: Device of the result; ``None`` means torch's current default
            device.

    Returns:
        Tensor of shape ``[k, 2]`` (``k >= 1``) whose rows are ``[start, end]``
        of consecutive ON periods.
    """
    max_duration_enabled = max_duration >= 0
    periods_on = []
    current_time = torch.tensor(0.0, dtype=dtype, device=device)
    # torch.stack only accepts tensors, so the cap must be a tensor as well
    # (stacking the raw float raised a TypeError whenever the cap was hit).
    max_time = torch.as_tensor(
        max_duration, dtype=current_time.dtype, device=current_time.device
    )
    while True:
        time_to_off = _sample_exponential(
            scale=time_off, generator=gen, dtype=dtype, device=device
        )
        time_to_bleach = _sample_exponential(
            scale=time_bleach, generator=gen, dtype=dtype, device=device
        )

        if time_to_off < time_to_bleach:
            # Transition to OFF state.
            off_time = current_time + time_to_off
            if max_duration_enabled and off_time >= max_time:
                # The cap falls inside this ON period: clip it and stop.
                periods_on.append(torch.stack((current_time, max_time)))
                break
            periods_on.append(torch.stack((current_time, off_time)))
            # Transition back to ON state after an OFF dwell.
            current_time = off_time + _sample_exponential(
                scale=time_on, generator=gen, dtype=dtype, device=device
            )
            if max_duration_enabled and current_time >= max_time:
                # The next ON period would start at or after the cap.
                break
        else:
            # Transition to BLEACHED state; this is the final ON period.
            bleach_time = current_time + time_to_bleach
            if max_duration_enabled and bleach_time >= max_time:
                bleach_time = max_time
            periods_on.append(torch.stack((current_time, bleach_time)))
            break

    return torch.stack(periods_on)


def batch_sample_periods(
    n: int,
    time_on: float,
    time_off: float,
    time_bleach: float,
    n_periods_max: int,
    gen: Generator,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """
    Simulate the ON periods of ``n`` independent ON/OFF/BLEACHED chains at once.

    Every chain starts ON at time 0. In each step, every still-active chain
    draws an ON -> OFF waiting time (mean ``time_off``) and an ON -> BLEACHED
    waiting time (mean ``time_bleach``); its current ON period ends at the
    earlier of the two. If OFF is strictly earlier, the chain waits an extra
    exponential time with mean ``time_on`` before its next ON period starts;
    otherwise it bleaches and becomes inactive.

    At most ``n_periods_max`` periods are recorded per chain. A chain that has
    not bleached after that many periods is simply truncated; nothing in the
    result marks it as such.

    Args:
        n: Number of independent chains.
        time_on: Mean waiting time from an OFF transition back to ON.
        time_off: Mean waiting time of an ON -> OFF transition.
        time_bleach: Mean waiting time of an ON -> BLEACHED transition.
        n_periods_max: Maximum number of ON periods stored per chain.
        gen: torch random generator (the draws go through ``torch.rand``, so
            a ``numpy`` generator is not accepted).
        dtype: Floating dtype of the result; ``None`` means torch's current
            default dtype.
        device: Device of the result; ``None`` means torch's current default
            device.

    Returns:
        Tensor of shape ``[n, n_periods_max, 2]`` where each row is
        ``[start, end]``. Unused (padded) periods are ``[inf, inf]``.
    """
    periods = torch.full((n, n_periods_max, 2), torch.inf, dtype=dtype, device=device)
    current_time = torch.zeros(n, dtype=dtype, device=device)
    active = torch.ones(n, dtype=torch.bool, device=device)

    for i in range(n_periods_max):
        idx = torch.where(active)[0]
        n_active = idx.numel()
        if n_active == 0:
            break
        # Competing waiting times for every active chain.
        t_off = _sample_exponential(
            scale=time_off,
            shape=n_active,
            device=device,
            dtype=dtype,
            generator=gen,
        )
        t_bleach = _sample_exponential(
            scale=time_bleach,
            shape=n_active,
            device=device,
            dtype=dtype,
            generator=gen,
        )
        # The ON period ends at whichever event happens first.
        event_time = torch.minimum(t_off, t_bleach)

        # Record period [start, end] for active chains.
        start = current_time[idx]
        periods[idx, i, 0] = start
        periods[idx, i, 1] = start + event_time

        # OFF if t_off < t_bleach (ties count as BLEACH), else BLEACH.
        off_event = t_off < t_bleach
        if torch.any(off_event):
            off_idx = idx[off_event]
            # Next ON period starts after the ON period plus an OFF dwell.
            t_on = _sample_exponential(
                scale=time_on,
                shape=off_idx.numel(),
                device=device,
                dtype=dtype,
                generator=gen,
            )
            current_time[off_idx] += event_time[off_event] + t_on

        # Bleached chains stop producing periods.
        active[idx[~off_event]] = False
    return periods
