from .uniform_timestepsampler import UniformTimestepSampler
from .shifted_logit_normal_timestepsampler import ShiftedLogitNormalTimestepSampler
from .dist_uniform_timestepsampler import DistUniformTimestepSampler

if __name__ == "__main__":
    # Manual visual check: draw a large batch of timesteps from the shifted
    # logit-normal sampler for several sequence lengths and show one
    # histogram per length.
    import matplotlib.pyplot as plt  # type: ignore

    num_draws = 1_000_000
    sequence_lengths = (1024, 2048, 4096, 8192)
    timestep_sampler = ShiftedLogitNormalTimestepSampler()

    for seq_length in sequence_lengths:
        drawn = timestep_sampler.sample(batch_size=num_draws, seq_length=seq_length)
        # The sampler's result exposes .numpy(); convert before handing it
        # to matplotlib.
        values = drawn.numpy()

        # density=True normalises the 100-bin histogram so it integrates to 1.
        plt.hist(values, bins=100, density=True)
        plt.xlabel("Timestep")
        plt.ylabel("Density")
        plt.title(f"Timestep Samples for Sequence Length {seq_length}")
        plt.show()
