import numpy as np

def gumbel_max(logits: np.ndarray, gumbel: np.ndarray = None) -> tuple[int, np.ndarray]:
    """Draw one categorical sample with the Gumbel-max trick.

    Adds standard Gumbel noise to ``logits`` and returns the index of the
    largest perturbed score. When ``logits`` are (possibly unnormalized)
    log-probabilities, the returned index is distributed as
    ``softmax(logits)``.

    Args:
        logits: Array-like of scores. ``np.argmax`` flattens the scores, so
            the returned index is into the flattened ``logits + gumbel``.
        gumbel: Optional pre-drawn noise, added element-wise to ``logits``.
            When ``None``, fresh i.i.d. standard Gumbel noise with the same
            shape as ``logits`` is drawn from ``np.random``.

    Returns:
        ``(index, gumbel)``: the arg-max index as a Python ``int`` and the
        noise that was added (so a draw can be reproduced or inspected).
    """
    logits = np.asarray(logits)
    if gumbel is None:
        # Sample U from (0, 1) instead of [0, 1): U == 0 would give
        # -log(-log(0)) = -inf noise. With low=tiny, high - low rounds to 1.0
        # and tiny + U rounds to U for every U > 0, so only the U == 0 draw
        # is affected.
        u = np.random.uniform(low=np.finfo(float).tiny, high=1.0, size=logits.shape)
        gumbel = -np.log(-np.log(u))
    scores = logits + gumbel
    return np.argmax(scores).item(), gumbel


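# NOTE: every returned sample equals `observation`, so returning `samples` is redundant.
# TODO: drop `samples` from the return value (check callers first: the function currently returns a 2-tuple).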
def gumbel_max_rejection_sampling(
    probs: np.ndarray,
    observation: int,
    max_iterations: int = 10000,
    n_samples: int = 1000,
    verbose: bool = False,
) -> "tuple[list[int], list[np.ndarray]] | None":
    """Rejection-sample Gumbel noise conditioned on a Gumbel-max outcome.

    Repeatedly calls ``gumbel_max(log(probs))`` and keeps only the draws whose
    arg-max equals ``observation``. This continues until ``n_samples`` draws
    have been accepted. The accepted noise arrays are therefore samples of the
    noise conditioned on the arg-max being ``observation``.

    Args:
        probs: Category probabilities. Zero entries become ``-inf`` logits and
            can never be selected.
        observation: Category index that accepted draws must produce.
        max_iterations: Maximum number of draws attempted for each accepted
            sample before giving up.
        n_samples: Number of accepted draws to collect.
        verbose: If true, print the index and noise of every draw.

    Returns:
        ``(samples, G)``:
            ``samples`` holds ``n_samples`` indices, all equal to
            ``observation``.
            ``G`` is the list of accepted noise arrays.
        Returns ``None`` (after printing a message) if some sample was not
        accepted within ``max_iterations`` draws.
    """
    # log(0) = -inf is intended: a zero-probability category never wins the
    # arg-max. Compute the logits once and silence the divide-by-zero warning.
    with np.errstate(divide="ignore"):
        logits = np.log(probs)

    samples = []
    G = []
    for _ in range(n_samples):
        # At most `max_iterations` draws; a match on the last draw is accepted.
        for _attempt in range(max_iterations):
            out_, g = gumbel_max(logits)
            if verbose:
                print(out_, g)
            if out_ == observation:
                break
        else:
            print("Max iterations reached")
            return None
        samples.append(out_)
        G.append(g)
    return samples, G

