import numpy as np

# Model predictions, keyed by the partially masked sequence tuple(xt)
# (token id 0 marks a masked position, see perturb()).
#
# For an array entry, row i is the predicted distribution for position i and
# column j is the probability of token id j + 1 -- this is how bound() indexes
# it (probs[i, x1[i] - 1]).  bound() only reads the rows of masked positions.
lookup = {
    # Both positions masked.
    (0, 0): np.array([
        [0.9, 0.1],
        [0.1, 0.9],
    ]),

    # First position visible, second masked.
    (1, 0): np.array([
        [1.0, 0.0],
        [0.2, 0.8],
    ]),
    (2, 0): np.array([
        [0.0, 1.0],
        [0.3, 0.7],
    ]),

    # First position masked, second visible.
    (0, 1): np.array([
        [0.8, 0.2],
        [1.0, 0.0],
    ]),
    (0, 2): np.array([
        [0.5, 0.5],
        [0.0, 1.0],
    ]),

    # Nothing masked: placeholder entries; bound() skips these sequences
    # before it ever consults the table.
    (1, 2): 'nothing',
    (2, 1): 'nothing',
    (2, 2): 'nothing',
    (1, 1): 'nothing',
}
def sample():
    """
    Sample a single ground-truth pair x1 from the data distribution q.

    q assigns probability 0.15, 0.5, 0.05 and 0.3 to the pairs
    (1, 1), (1, 2), (2, 1) and (2, 2) respectively.

    Returns a newly allocated int ndarray of shape (2,).
    """
    outcomes = ((1, 1), (1, 2), (2, 1), (2, 2))
    weights = np.array([0.15, 0.5, 0.05, 0.3])

    # A single categorical draw over the outcome indices, weighted by q.
    pick = np.random.choice(len(outcomes), p=weights)
    return np.array(outcomes[pick], dtype=int)



def perturb(seq, t):
    """
    Randomly mask entries of `seq`.

    Each position is kept independently with probability t and is otherwise
    replaced by the mask token id 0 (so it is masked with probability 1 - t).
    Returns an int array; `seq` itself is not modified.
    """
    threshold = 1 - t
    draws = np.random.rand(*seq.shape)      # one uniform draw per position
    survivors = draws > threshold           # True where the token is kept
    indicator = survivors.astype(int)
    # Multiplying by the 0/1 indicator zeroes out exactly the masked slots.
    return seq * indicator



def bound(n_samples=10_000_000, report_every=100_000):
    """
    Monte-Carlo estimate of the MD4 / perplexity upper bound.
    Returns average loss PER TOKEN (natural-log cross-entropy).

    Each step draws a ground-truth sequence x1 with sample(), a level
    t ~ U[0, 1) and a partially masked copy xt = perturb(x1, t).  For every
    masked position the negative log-probability that lookup[tuple(xt)]
    assigns to the true token is accumulated, weighted by 1 / (1 - t), the
    reciprocal of the per-position masking probability used by perturb().

    Parameters
    ----------
    n_samples : int
        Number of Monte-Carlo steps (sequences drawn).  Must be >= 1.
    report_every : int
        Print the running estimate every `report_every` steps; a falsy value
        (0 / None) disables reporting.

    Returns
    -------
    float
        Accumulated weighted loss divided by the total number of tokens
        drawn (sequence length times n_samples).

    Raises
    ------
    ValueError
        If n_samples is smaller than 1 (the average would be undefined).
    """
    if n_samples < 1:
        raise ValueError("n_samples must be a positive integer")

    loss_sum = 0.0      # weighted negative log-likelihood over all tokens
    token_cnt = 0       # number of tokens drawn so far (masked or not)

    for n in range(1, n_samples + 1):
        x1 = sample()            # ground truth             (shape (2,))
        t = np.random.rand()     # U[0, 1), so 1 - t is never zero
        xt = perturb(x1, t)      # partially masked version (shape (2,))

        # Normalise per token: every position of every sequence counts,
        # including the ones that contribute zero loss.
        token_cnt += x1.size

        masked = np.flatnonzero(xt == 0)
        if masked.size:          # fully visible sequences add no loss
            probs = lookup[tuple(xt)]
            inv = 1.0 / (1.0 - t)    # importance weight
            for pos in masked:
                # Tokens are 1-based, table columns are 0-based.
                loss_sum += -inv * np.log(probs[pos, x1[pos] - 1])

        # Report on every scheduled step, even when nothing was masked.
        if report_every and n % report_every == 0:
            print(f"step {n:>8,d}   running-bound = {loss_sum/token_cnt:.4f}")

    return loss_sum / token_cnt

# Run the (long) simulation only when executed as a script, so importing
# this module for its functions does not trigger it.
if __name__ == "__main__":
    bound()