import torch
from mmhug.registry import HF_MODELS
from .base_timestepsampler import TimestepSampler


@HF_MODELS.register_module()
class UniformTimestepSampler(TimestepSampler):
    """Sample timesteps uniformly from ``[min_value, max_value)``.

    Every timestep is drawn independently via ``torch.rand``. The sequence
    length plays no role in the distribution.

    Args:
        min_value: Lower bound of the sampling interval (inclusive).
        max_value: Upper bound of the sampling interval (exclusive).
    """

    def __init__(self, min_value: float = 0.0, max_value: float = 1.0):
        self.min_value = min_value
        self.max_value = max_value

    def sample(
        self,
        batch_size: int,
        seq_length: int | None = None,  # noqa: ARG002
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """Draw ``batch_size`` timesteps uniformly from ``[min_value, max_value)``.

        Args:
            batch_size: Number of timesteps to draw.
            seq_length: Accepted for call-signature compatibility; ignored by
                uniform sampling.
            device: Device on which the returned tensor is allocated
                (``None`` means torch's current default device).

        Returns:
            1-D tensor of shape ``(batch_size,)`` in torch's default floating
            dtype.
        """
        # Scale U[0, 1) samples onto the configured interval.
        span = self.max_value - self.min_value
        return torch.rand(batch_size, device=device) * span + self.min_value

    def sample_for(self, batch: torch.Tensor) -> torch.Tensor:
        """Draw one timestep per item of a 3-D ``batch`` tensor.

        Only the shape and device of ``batch`` are read; its values are not.

        Args:
            batch: Tensor with exactly 3 dimensions. The first dimension is
                treated as the batch size.

        Returns:
            Tensor of shape ``(batch.shape[0],)`` on ``batch.device``.

        Raises:
            ValueError: If ``batch`` does not have exactly 3 dimensions.
        """
        if batch.ndim != 3:
            raise ValueError(f"Batch should have 3 dimensions, got {batch.ndim}")

        return self.sample(batch.shape[0], device=batch.device)
