import numpy as np

# Large penalty weight used in ``w_out`` to suppress a logit once the
# corresponding state component is non-zero. With the weights below the
# counting behaviour stays correct for counts n < 4 * max_count: after the
# trigger, the "d" logit (n - max_count) must stay strictly below the "r"
# logit (0.75 * n), since np.argmax breaks ties towards the lower index.
max_count = 1000

# Embeddings are 2-d. Because ``w`` below simply adds the input to the state,
# component 0 of the state is a running count (+1 per "d", -1 per "r") and
# component 1 records that the trigger "t" has been seen.
h = np.asarray([ 0,0 ]) # initial recurrent state
d = np.asarray([ 1,0 ]) # demo token "d": increments the count
r = np.asarray([ -1,0 ]) # resp token "r": decrements the count
t = np.asarray([ 0,1 ]) # trigger token "t": sets the trigger flag
# Input-token id -> embedding: 0 = d, 1 = r, 2 = t.
embs = np.asarray([ d,r,t ])

# w = [I | I]: new_state = inpt + state, i.e. the state accumulates inputs.
w = np.asarray([
    [1,0,1,0],
    [0,1,0,1],
])
# One row of logits per output id, as a function of state (count, flag):
#   id 0: count - max_count*flag  -> wins while the trigger is unseen and count > 0
#   id 1: 0.75*count              -> wins after the trigger while count > 0
#   id 2: flag - max_count*count  -> wins once the count is back to 0 after the
#                                    trigger (the driver stops generating on id 2)
w_out = np.asarray([
    [1,-max_count],
    [0.75,0],
    [-max_count,1]
])
ws = {
    "w": w,
    "w_out": w_out,
}

def model(ws, state, inpt):
    """Run one step of the linear recurrent model.

    The input embedding and the previous state are concatenated and mapped
    linearly to the new state; the new state is then mapped linearly to
    output logits. No nonlinearity is applied.

    ws: dict
        "w": ndarray, shape (S, I + S)
            the main recurrent weight; its first I columns act on ``inpt``
            and its last S columns act on ``state``
        "w_out": ndarray, shape (V, S)
            the output weight mapping the new state to V logits
    state: ndarray, shape (S,)
        the recurrent state from the previous step
    inpt: ndarray, shape (I,)
        the input-token embedding

    Returns
    -------
    new_state: ndarray, shape (S,)
        the updated recurrent state
    new_token: ndarray, shape (V,)
        unnormalised logits over the output tokens

    With the module-level weights of this script, S = I = 2 and V = 3.
    """
    # Input comes first, then state, matching the column layout of ws["w"].
    full_inpt = np.concatenate([inpt, state], axis=0)
    # Plain matrix-vector products (equivalent to einsum "nm,m->n").
    new_state = ws["w"] @ full_inpt
    new_token = ws["w_out"] @ new_state
    return new_state, new_token

# Output-token id -> printable label. Generation stops on id 2 ("e"). The
# input fed at the trigger position is embs[2], which is printed as "t".
id2word = ["d", "r", "e"]
for targ_num in range(1, 10):
    # Prompt: ``targ_num`` "d" tokens, then a single "t" trigger token. At
    # every other position the model's previous prediction is fed back in.
    state = h
    pred = 0
    s = ""
    loop = 0
    # With the weights above a run takes 2 * targ_num + 1 steps (targ_num "d",
    # one "t", targ_num "r"). The cap turns a model that never emits id 2
    # (e.g. targ_num >= 4 * max_count, or edited weights) into an error
    # instead of an infinite loop.
    max_steps = 2 * targ_num + 10
    while pred < 2:
        if loop >= max_steps:
            raise RuntimeError(
                f"no end token after {max_steps} steps for targ_num={targ_num}; "
                f"partial output: {s!r}"
            )
        if loop != targ_num:
            inpt = embs[pred]
            s += id2word[pred]
        else:
            inpt = embs[2]
            s += "t"
        state, logits = model(ws, state, inpt)
        pred = np.argmax(logits)
        loop += 1
    s += id2word[pred]
    print()
    print(targ_num, "Prediction:", s)

