import numpy as np

# Select which reference model to match; swap the commented assignment to
# compare against the other one.
# model = "6k"
model = "300k"

# Target parameter counts for each reference model, plus its depth and widths.
if model == "6k":
    # loss: 5.913378843203889
    target_total = 827714
    target_nonembed = 6143
    n_layer = 2
    d_model = 16
    # Previously missing (NameError further down). With these widths the
    # non-embedding prediction is 6144, matching the 6143 target.
    d_attn = d_model
    d_ff = 4 * d_model
elif model == "300k":
    # loss: 4.321427838772197
    target_total = 5262559
    target_nonembed = 331717
    n_layer = 3

    # To grid-search the widths instead, use mutually broadcastable ranges:
    # d_model = np.arange(64, 128 + 1)[:, None, None]
    # d_attn = np.arange(64, 128 + 1)[None, :, None]
    # d_ff = np.arange(64, 1024 + 1)[None, None, :]

    d_model = 96
    d_attn = d_model
    d_ff = 4 * d_model
else:
    raise ValueError(f"unknown model {model!r}; expected '6k' or '300k'")

# Vocabulary size and context length (these match GPT-2's).
n_vocab = 50257
n_ctx = 1024

# Predict parameter counts for this architecture (assumes a GPT-2-style
# decoder). Works elementwise when the widths are broadcastable arrays.

# Token embeddings plus learned position embeddings.
embed = (n_vocab + n_ctx) * d_model
# Per layer: 4 * d_model * d_attn attention weights (Q, K, V, output) plus
# 2 * d_model * d_ff feedforward weights.
nonembed = 2 * d_model * n_layer * (2 * d_attn + d_ff)
other = (
    # Attention biases: Q/K/V project to d_attn each, the output projection
    # back to d_model. (Was 4 * d_model, which is only right when
    # d_attn == d_model and so mis-counted during a d_attn grid search.)
    n_layer * (3 * d_attn + d_model)
    + n_layer * (d_ff + d_model)  # feedforward biases
    # Two LayerNorms per block; presumably gain + bias for each. A final
    # LayerNorm after the last block, if the model has one, is not counted.
    + 2 * 2 * n_layer * d_model
)
total = embed + nonembed + other

# If any width is an array, a grid search was requested: keep only the
# configuration whose predicted counts best match the targets.
if any(np.asarray(x).size > 1 for x in (d_model, d_attn, d_ff)):
    # The total-count mismatch is weighted 1000x less than the non-embedding
    # mismatch. (An unweighted sum of both errors is an alternative.)
    loss = 0.001 * np.abs(total - target_total) + np.abs(nonembed - target_nonembed)
    grid_shape = np.shape(loss)
    # Unravel against the loss's own broadcast shape. Using the widths' sizes
    # breaks when widths share an axis (e.g. d_attn = d_model) or are scalars.
    best = np.unravel_index(np.argmin(loss), grid_shape)

    def _at_best(x):
        """Return `x` evaluated at the best grid point; scalars broadcast."""
        return np.broadcast_to(x, grid_shape)[best]

    embed = _at_best(embed)
    nonembed = _at_best(nonembed)
    other = _at_best(other)
    total = _at_best(total)
    d_model = _at_best(d_model)
    d_attn = _at_best(d_attn)
    d_ff = _at_best(d_ff)

# The targets give no separate embedding count, so reuse the predicted
# `embed` and attribute whatever remains of the target total to "other".
target_other = target_total - target_nonembed - embed

# Report the chosen architecture, then target vs. predicted parameter counts.
print("Model: n_layer {}, d_model {}, d_attn {}, d_ff {}".format(n_layer, d_model, d_attn, d_ff))
print(
    "Target: total {}, embed {}, nonembed {}, other {}".format(
        target_total, embed, target_nonembed, target_other
    )
)
print(
    "Predict: total {}, embed {}, nonembed {}, other {}".format(
        total, embed, nonembed, other
    )
)
