import time
import numpy as np
import torch

def _time_loop(draw, n_iters):
    """Return the elapsed seconds spent calling ``draw()`` ``n_iters`` times."""
    # One untimed call first, so one-time first-call costs are not counted.
    draw()
    start = time.perf_counter()  # monotonic, high-resolution interval clock
    for _ in range(n_iters):
        draw()
    return time.perf_counter() - start


def benchmark(n_categories=10000, n_samples=3, n_iters=100, seed=0):
    """Time weighted index sampling with torch and NumPy.

    Every method draws the same number of indices (``n_samples``) from the
    same normalized probability vector, so the timings are comparable.
    Both NumPy cases use the same ``Generator`` API, so replace=True versus
    replace=False is the only difference between them.

    Args:
        n_categories: Length of the random weight vector to sample from.
        n_samples: Number of indices drawn per call.
        n_iters: Number of timed calls per method.
        seed: Seed for the weight vector and for both random generators.

    Returns:
        A dict mapping a method label to the total elapsed seconds for
        ``n_iters`` calls. The keys are inserted in this order: torch without
        replacement, torch with replacement, NumPy with replacement, NumPy
        without replacement.

    Raises:
        ValueError: If ``n_categories`` or ``n_samples`` is < 1, if
            ``n_iters`` is negative, or if ``n_samples > n_categories``
            (sampling without replacement would then be impossible).
    """
    if n_categories < 1 or n_samples < 1:
        raise ValueError("n_categories and n_samples must be >= 1")
    if n_iters < 0:
        raise ValueError("n_iters must be >= 0")
    if n_samples > n_categories:
        raise ValueError(
            "n_samples cannot exceed n_categories when sampling "
            "without replacement"
        )

    rng = np.random.default_rng(seed)
    probs = rng.random(n_categories)
    probs /= probs.sum()
    # Shares memory with `probs`, so torch sees the same normalized vector.
    probs_t = torch.from_numpy(probs)
    torch_gen = torch.Generator().manual_seed(seed)

    draws = {
        "torch.multinomial replacement=False": lambda: torch.multinomial(
            probs_t, n_samples, replacement=False, generator=torch_gen
        ),
        "torch.multinomial replacement=True": lambda: torch.multinomial(
            probs_t, n_samples, replacement=True, generator=torch_gen
        ),
        "numpy Generator.choice replace=True": lambda: rng.choice(
            n_categories, n_samples, p=probs, replace=True
        ),
        "numpy Generator.choice replace=False": lambda: rng.choice(
            n_categories, n_samples, p=probs, replace=False
        ),
    }
    return {label: _time_loop(draw, n_iters) for label, draw in draws.items()}


def main():
    """Run the benchmark with default settings and print labeled timings."""
    for label, seconds in benchmark().items():
        print(f"{label}: {seconds:.6f} s")


if __name__ == "__main__":
    main()