import numpy as np

# Vocabulary size; sets the embedding / de-embedding parameter counts.
vocab_size = 50_257
# Context length; enters the attention FLOP term of compute_gpt.
seq_len = 1_024
# LMU FLOP model used by default in compute_lmud and compute_lmumlp.
default_lmu_method = "fft"


def get_transformer_layers(n_layers, fixed_layers=2):
    """Return the layer count of the comparison transformer.

    Args:
        n_layers: layer count of the model being compared.
        fixed_layers: if not ``None`` (default 2), always use this many layers
            regardless of ``n_layers``; pass ``None`` to match ``n_layers``.

    Returns:
        The number of transformer layers to use for the comparison model.
    """
    return n_layers if fixed_layers is None else fixed_layers


def compute_lmu_flops(d, order, method="ss"):
    """Estimate the FLOPs of an LMU applied to ``d`` channels.

    Args:
        d: number of channels the LMU is applied to.
        order: LMU memory order (state size per channel).
        method: one of
            ``"ss"``   -- state-space update using the (A, B) matrices,
            ``"rk<N>"`` -- N parsed from the suffix (e.g. ``"rk4"``);
                          presumably an N-stage Runge-Kutta update,
            ``"fft"``  -- FFT-based evaluation.

    Returns:
        Estimated FLOP count.

    Raises:
        ValueError: if ``method`` is not recognized, or an ``"rk"`` method has
            no integer suffix.
    """
    if method == "ss":
        return 2 * d * (order ** 2 + order)  # using (A, B) matrix updates
    if method.startswith("rk"):
        rk_order = int(method[2:])  # int("") raises ValueError for bare "rk"
        return 6 * rk_order * d * order
    if method == "fft":
        return d * (61 * order + 55)
    # Fail loudly instead of returning None, which would only surface later
    # as a confusing TypeError in the caller's arithmetic.
    raise ValueError(f"unknown LMU method: {method!r}")


def transformer_params_forward(n_layers, embed_dim):
    """Return the transformer's non-embedding parameter count.

    Uses the closed form ``12 * n_layers * embed_dim**2``; this is the inverse
    of ``transformer_params_reverse`` for a fixed ``n_layers``.
    """
    per_layer_factor = 12
    return per_layer_factor * n_layers * embed_dim ** 2


def transformer_params_reverse(n_layers, params):
    """Return the embed_dim for which ``12 * n_layers * embed_dim**2 == params``.

    The result is a NumPy float and is generally not an integer.
    """
    denominator = 12 * n_layers
    return np.sqrt(params / denominator)


def comparison_transformer_info(n_layers, non_embed_params):
    """Size a comparison GPT with the same non-embedding parameter budget.

    Args:
        n_layers: layer count of the model being compared (forwarded to
            ``get_transformer_layers``).
        non_embed_params: non-embedding parameter budget to match.

    Returns:
        Tuple ``(layers, embed_dim, info)`` where ``embed_dim`` is the
        (possibly non-integer) width solving the parameter budget and ``info``
        is the silent ``compute_gpt`` result for that configuration.
    """
    layers = get_transformer_layers(n_layers)
    width = transformer_params_reverse(layers, non_embed_params)
    return layers, width, compute_gpt(layers, width, verbose=0)


def compute_gpt(n_layers, embed_dim, ff_dim=None, name="", verbose=1):
    """Estimate parameter and forward FLOP counts for a GPT-style model.

    Args:
        n_layers: number of transformer layers.
        embed_dim: model width; the attention width is set equal to it.
        ff_dim: feed-forward hidden width; defaults to ``4 * embed_dim``.
        name: label used only in the printed summary.
        verbose: 0 prints nothing; >= 1 prints a summary including the ratio
            to a comparison transformer; >= 2 also prints that comparison
            transformer's configuration.

    Returns:
        dict with keys ``non_embed_params``, ``embed_params``, ``params``,
        ``non_embed_ops``, ``deembed_ops`` and ``ops``.

    Parameter counts exclude biases and layer normalization. Embedding
    lookups are counted as 0 FLOPs (as TensorFlow does); only the
    de-embedding matmul (``2 * embed_params``) is charged.
    """
    if ff_dim is None:
        ff_dim = 4 * embed_dim
    attn_dim = embed_dim

    # Per layer: 2 * embed_dim * (2 * attn_dim + ff_dim) weights, i.e. the
    # attention projections plus the feed-forward matrices.
    layer_width = 2 * attn_dim + ff_dim
    n_core = 2 * embed_dim * n_layers * layer_width
    n_embed = vocab_size * embed_dim
    n_total = n_embed + n_core

    # Context-dependent attention matmuls scale with seq_len and carry no weights.
    attn_context_ops = 2 * n_layers * seq_len * (attn_dim + embed_dim)
    core_ops = 2 * n_core + attn_context_ops
    deembed_ops = 2 * n_embed

    summary = dict(
        non_embed_params=n_core,
        embed_params=n_embed,
        params=n_total,
        non_embed_ops=core_ops,
        deembed_ops=deembed_ops,
        ops=deembed_ops + core_ops,
    )
    if not verbose:
        return summary

    tr_layers, tr_d, tr_info = comparison_transformer_info(n_layers, n_core)

    print(f"GPT {name}, layers={n_layers}, d={embed_dim}, ff={ff_dim}, attn={attn_dim}")
    print(f"  params: {n_core:,} ({n_total:,} total, {n_embed:,} embed)")
    print(
        f"  flops: {core_ops:,} ("
        f"{core_ops / n_core:0.1f} flops/param, "
        f"{core_ops / tr_info['non_embed_ops']:0.1f} trans. ratio)"
    )
    print(f"  GPT layer dot flops: {attn_context_ops // n_layers:,}")

    if verbose >= 2:
        print(
            f"  * comparison GPT1: layers={tr_layers}, d={tr_d:0.1f}, "
            f"N={tr_info['non_embed_params']:,.0f}, "
            f"FLOPs={tr_info['non_embed_ops']:,.0f}"
        )

    return summary


def compute_lmud(
    n_layers,
    embed_dim,
    order,
    ff_dim=None,
    lmud_order=None,
    n_filters=1,
    initial_dense=None,
    final_dense=True,
    share_filters=True,
    option_2=False,
    order_reduce=False,
    pre_ffn=False,
    post_ffn=False,
    lmu_method=default_lmu_method,
    name="",
    verbose=1,
):
    """Estimate parameter and FLOP counts for an LMUD model.

    Args:
        n_layers: number of LMUD layers.
        embed_dim: embedding width (also sizes the embedding/de-embedding).
        order: LMU memory order ``q``.
        ff_dim: channel width of the LMU/LMUD inner tensor, which has shape
            ``(ff_dim, order)``; defaults to ``embed_dim``.
        lmud_order: reduced order ``q'``. ``None`` means ``order``; a value
            below 1 is treated as a fraction of ``order`` (truncated to int).
        n_filters: number of contraction filters (only used with a post-FFN).
        initial_dense: unsupported; must be ``None``.
        final_dense: add a final ``embed_dim x embed_dim`` dense layer.
        share_filters: share contraction weights across the ``ff_dim``
            channels instead of one set per channel.
        option_2: use 3 streams instead of 2 in the dense L / dense R layers.
        order_reduce: make the first dense-L stream ``order x lmud_order`` and
            the remaining ``streams - 1`` streams ``lmud_order x lmud_order``.
        pre_ffn, post_ffn: FFN hidden-width multipliers relative to ``ff_dim``;
            ``True`` means 2, ``False`` (or any value <= 0) disables the FFN.
        lmu_method: LMU FLOP model passed to ``compute_lmu_flops``.
        name: label used only in the printed summary.
        verbose: 0 prints nothing; 1 prints a summary with per-layer LMU/LMUD
            FLOPs; >= 2 prints a per-component breakdown and the comparison
            transformer configuration.

    Returns:
        dict with the same keys as ``compute_gpt``: ``non_embed_params``,
        ``embed_params``, ``params``, ``non_embed_ops``, ``deembed_ops``,
        ``ops``.

    Raises:
        NotImplementedError: if ``initial_dense`` is not ``None``.
    """
    ff_dim = embed_dim if ff_dim is None else ff_dim
    if lmud_order is None:
        lmud_order = order
    elif lmud_order < 1:
        # Fractional values are a ratio of the full LMU order.
        lmud_order = int(lmud_order * order)

    # Normalize FFN flags to width multipliers: True -> 2, False -> 0.
    pre_ffn = 2 if pre_ffn is True else 0 if pre_ffn is False else pre_ffn
    post_ffn = 2 if post_ffn is True else 0 if post_ffn is False else post_ffn
    pre_ffn_dim = 0 if pre_ffn <= 0 else int(pre_ffn * ff_dim)
    post_ffn_dim = 0 if post_ffn <= 0 else int(post_ffn * ff_dim)
    streams = 3 if option_2 else 2

    # Raise explicitly: an `assert` is stripped under `python -O`, which would
    # let a non-None initial_dense be silently ignored.
    if initial_dense is not None:
        raise NotImplementedError("initial_dense is not supported")

    # ---- Parameter counts (biases and layer normalization excluded) ----
    lmu_params = 0  # the LMU contributes no counted parameters
    pre_ffn_params = 0 if pre_ffn_dim == 0 else pre_ffn_dim * (embed_dim + ff_dim)
    if order_reduce:
        dense_l_params = lmud_order * (order + (streams - 1) * lmud_order)
    else:
        dense_l_params = streams * lmud_order * order
    dense_r_params = streams * ff_dim ** 2
    if post_ffn_dim > 0:
        contraction_params = lmud_order * n_filters * (1 if share_filters else ff_dim)
        ffn_params = post_ffn_dim * (embed_dim + ff_dim) * n_filters ** 2
        dense_p_params = contraction_params + ffn_params
    else:
        dense_p_params = lmud_order
    lmud_params = dense_l_params + dense_r_params + dense_p_params
    final_dense_params = embed_dim ** 2 if final_dense else 0
    non_embed_params = n_layers * (pre_ffn_params + lmud_params) + final_dense_params
    embed_params = vocab_size * embed_dim
    params = embed_params + non_embed_params

    # ---- FLOP counts; the LMUD inner tensor has shape (ff_dim, order) ----
    if order_reduce:
        dense_l = 2 * ff_dim * lmud_order * (order + (streams - 1) * lmud_order)
    else:
        dense_l = 2 * streams * ff_dim * order * lmud_order
    dense_r = 2 * streams * ff_dim ** 2 * lmud_order
    # Weight-free matmuls between lmud_order-sized factors.
    matmul_gg = 2 * ff_dim * lmud_order ** 2
    matmul_gx = 2 * ff_dim * lmud_order ** 2 + lmud_order * ff_dim
    if post_ffn_dim > 0:
        contraction = 2 * ff_dim * lmud_order * n_filters
        ffn = 2 * ffn_params
        dense_p = contraction + ffn
    else:
        dense_p = 2 * ff_dim * lmud_order
    lmud = dense_l + dense_r + matmul_gg + matmul_gx + dense_p

    pre_ffn_ops = 2 * pre_ffn_params
    lmu = compute_lmu_flops(ff_dim, order, method=lmu_method)

    # Distinct name so the `final_dense` flag argument is not shadowed.
    final_dense_ops = 2 * final_dense_params

    non_embed_ops = n_layers * (lmu + lmud + pre_ffn_ops) + final_dense_ops
    deembed_ops = 2 * embed_params
    # Embedding lookups are counted as 0 FLOPs (as TensorFlow does).
    ops = deembed_ops + non_embed_ops

    if verbose:
        tr_layers, tr_d, tr_info = comparison_transformer_info(
            n_layers, non_embed_params
        )

        print(
            f"LMUD {name}, layers={n_layers}, d={embed_dim}, ff={ff_dim}, q={order}, "
            f"q'={lmud_order}, filt={n_filters}, share={share_filters}, "
            f"pre_ffn={pre_ffn}, ffn={post_ffn}"
        )
        print(
            f"  params: {non_embed_params:,} ({params:,} total, {embed_params:,} embed,"
            f" {n_layers * lmu_params:,} lmu)"
        )
        print(
            f"  flops: {non_embed_ops:,} ("
            f"{non_embed_ops / non_embed_params:0.1f} flops/param, "
            f"{non_embed_ops / tr_info['non_embed_ops']:0.1f} trans. ratio)"
        )

        if verbose >= 2:
            print(f"  LMU: {lmu_params:,} p, {lmu:,} flops")
            if pre_ffn_params > 0:
                print(f"  Pre FFN: {pre_ffn_params:,} p, {pre_ffn_ops:,} flops")
            print(f"  Dense L (order): {dense_l_params:,} p, {dense_l:,} flops")
            print(f"  Dense R (dims): {dense_r_params:,} p, {dense_r:,} flops")
            print(f"  Matmuls: {0:,} p, {matmul_gg + matmul_gx:,} flops")
            print(f"  Dense P: {dense_p_params:,} p, {dense_p:,} flops")
            print(f"  LMUD layer total: {lmud_params:,} p, {lmud:,} flops")
            print(
                f"  * comparison GPT1: layers={tr_layers}, d={tr_d:0.1f}, "
                f"N={tr_info['non_embed_params']:,.0f}, "
                f"FLOPs={tr_info['non_embed_ops']:,.0f}"
            )
        else:
            print(f"  LMU layer flops: {lmu:,}")
            print(
                f"  LMUD layer flops: {lmud:,} ({lmud / lmud_params:0.1f} flops/param)"
            )

    # Same result shape as compute_gpt so callers can compare models uniformly.
    return dict(
        non_embed_params=non_embed_params,
        embed_params=embed_params,
        params=params,
        non_embed_ops=non_embed_ops,
        deembed_ops=deembed_ops,
        ops=ops,
    )


def compute_lmumlp(
    n_layers,
    embed_dim,
    order,
    ff_dim=None,
    n_filters=1,
    share_filters=True,
    lmu_method=default_lmu_method,
    name="",
    verbose=1,
):
    """Estimate parameter and FLOP counts for an LMU-MLP model.

    Args:
        n_layers: number of LMU-MLP layers.
        embed_dim: embedding width (also sizes the embedding/de-embedding).
        order: LMU memory order.
        ff_dim: hidden width of the MLP; defaults to ``embed_dim``.
        n_filters: number of contraction filters over the memory order.
        share_filters: share contraction weights across the ``ff_dim``
            channels instead of one set per channel.
        lmu_method: LMU FLOP model passed to ``compute_lmu_flops``.
        name: label used only in the printed summary.
        verbose: 0 prints nothing; >= 1 prints a summary; >= 2 also prints
            the comparison transformer configuration.

    Returns:
        dict with the same keys as ``compute_gpt``: ``non_embed_params``,
        ``embed_params``, ``params``, ``non_embed_ops``, ``deembed_ops``,
        ``ops``.
    """
    ff_dim = embed_dim if ff_dim is None else ff_dim

    # ---- Parameter counts (biases and layer normalization excluded) ----
    lmu_params = 0  # the LMU contributes no counted parameters
    contraction_params = n_filters * order * (1 if share_filters else ff_dim)
    # One embed_dim x ff_dim input dense plus an (ff_dim * n_filters) x embed_dim
    # output dense.
    dense_params = embed_dim * ff_dim * (n_filters + 1)
    lmu_mlp_params = contraction_params + dense_params
    non_embed_params = n_layers * lmu_mlp_params
    embed_params = vocab_size * embed_dim
    params = embed_params + non_embed_params

    # ---- FLOP counts ----
    pre_dense = 2 * embed_dim * ff_dim
    contraction_ops = 2 * ff_dim * order * n_filters
    elmul_xx = ff_dim * n_filters
    post_dense = 2 * ff_dim * n_filters * embed_dim
    lmu_mlp = pre_dense + contraction_ops + post_dense + elmul_xx

    # NOTE(review): the contraction above works on ff_dim channels, but the LMU
    # cost here is computed over embed_dim channels (compute_lmud uses ff_dim).
    # Kept unchanged to preserve reported numbers -- confirm which width the
    # LMU actually runs over.
    lmu = compute_lmu_flops(embed_dim, order, method=lmu_method)
    non_embed_ops = n_layers * (lmu + lmu_mlp)
    deembed_ops = 2 * embed_params
    # Embedding lookups are counted as 0 FLOPs (as TensorFlow does).
    ops = deembed_ops + non_embed_ops

    if verbose:
        tr_layers, tr_d, tr_info = comparison_transformer_info(
            n_layers, non_embed_params
        )

        print(
            f"LMU-MLP {name}, layers={n_layers}, d={embed_dim}, ff={ff_dim}, q={order}, "
            f"filt={n_filters}, share={share_filters}"
        )
        print(
            f"  params: {non_embed_params:,} ({params:,} total, {embed_params:,} embed, {n_layers * lmu_params:,} lmu)"
        )
        print(
            f"  flops: {non_embed_ops:,} ("
            f"{non_embed_ops / non_embed_params:0.1f} flops/param, "
            f"{non_embed_ops / tr_info['non_embed_ops']:0.1f} trans. ratio)"
        )
        print(f"  LMU layer flops: {lmu:,}")
        print(
            f"  other layer flops: {lmu_mlp:,} ({lmu_mlp / lmu_mlp_params:0.1f} flops/param)"
        )

        if verbose >= 2:
            print(
                f"  * comparison GPT1: layers={tr_layers}, d={tr_d:0.1f}, "
                f"N={tr_info['non_embed_params']:,.0f}, "
                f"FLOPs={tr_info['non_embed_ops']:,.0f}"
            )

    # Same result shape as compute_gpt so callers can compare models uniformly.
    return dict(
        non_embed_params=non_embed_params,
        embed_params=embed_params,
        params=params,
        non_embed_ops=non_embed_ops,
        deembed_ops=deembed_ops,
        ops=ops,
    )


# Configurations whose name labels carry sizes around 50k (55k-59k). Each dict's
# "kind" selects the estimator (compute_gpt / compute_lmud / compute_lmumlp);
# all other keys are passed to it as keyword arguments.
models50 = [
    dict(
        kind="gpt",
        name="row 66, 56k",
        embed_dim=64,
        ff_dim=64,
        n_layers=2,
    ),
    dict(
        kind="lmumlp",
        name="row 98, 59k",
        n_layers=4,
        embed_dim=48,
        ff_dim=96,
        order=256,
        n_filters=2,
        share_filters=True,
    ),
    dict(
        kind="lmud",
        name="row 109, 56k",
        n_layers=3,
        embed_dim=48,
        order=80,
    ),
    dict(
        kind="lmud",
        name="row 119, 56k",
        n_layers=2,
        embed_dim=48,
        order=138,
        lmud_order=0.333,
        post_ffn=True,
    ),
    dict(
        kind="lmud",
        name="row 136/7, 56k",
        embed_dim=48,
        order=138,
        n_layers=3,
        lmud_order=0.1,
        post_ffn=True,
    ),
    dict(
        kind="lmud",
        name="row 138, 55k",
        embed_dim=48,
        ff_dim=36,
        order=64,
        n_layers=3,
        lmud_order=0.25,
        pre_ffn=2,
        # This model's post-FFN width was originally derived from embed_dim
        # (2 * 48 = 96). The current code scales post_ffn by ff_dim instead, so
        # 2.667 * 36 (truncated to 96) reproduces the same width.
        post_ffn=2.667,
        order_reduce=True,
        option_2=False,
        final_dense=False,
    ),
    dict(
        kind="lmud",
        name="row 139, 55k",
        embed_dim=48,
        ff_dim=36,
        order=64,
        n_layers=3,
        lmud_order=0.25,
        pre_ffn=2,
        post_ffn=2,
        order_reduce=True,
        option_2=True,
        final_dense=False,
    ),
]

# Configurations whose name labels carry sizes around 300k (301k-336k).
# Same format as models50: "kind" picks the estimator, other keys are kwargs.
models300 = [
    dict(
        kind="gpt",
        # (note: row 96 uses a different n_heads, but that does not affect params/flops)
        name="row 96/97, 336k",
        embed_dim=96,
        n_layers=3,
    ),
    dict(
        kind="lmud",
        name="row 115, 302k",
        n_layers=4,
        embed_dim=112,
        order=152,
    ),
    dict(
        kind="lmud",
        name="row 133, 301k",
        embed_dim=112,
        order=173,
        n_layers=3,
        lmud_order=0.333,
        post_ffn=True,
    ),
    dict(
        kind="lmumlp",
        name="row 135, 307k",
        n_layers=4,
        embed_dim=112,
        order=174,
        n_filters=5,
        share_filters=True,
    ),
]

# LMUD configurations spanning the size labels 56k to 307k; this is the list
# selected for reporting by default. Same format as the other model lists:
# "kind" picks the estimator, other keys are kwargs.
models_paper = [
    dict(
        kind="lmud",
        name="row 119, 56k",
        n_layers=2,
        embed_dim=48,
        order=138,
        lmud_order=0.333,
        post_ffn=True,
    ),
    dict(
        kind="lmud",
        name="row 136/7, 56k",
        embed_dim=48,
        order=138,
        n_layers=3,
        lmud_order=0.1,
        post_ffn=True,
    ),
    dict(
        kind="lmud",
        name="row 127, 102k",
        n_layers=3,
        embed_dim=65,
        order=100,
        lmud_order=0.333,
        post_ffn=True,
    ),
    dict(
        kind="lmud",
        name="row 128, 200k",
        n_layers=3,
        embed_dim=91,
        order=143,
        lmud_order=0.333,
        post_ffn=True,
    ),
    dict(
        kind="lmud",
        name="row 131, 307k",
        n_layers=3,
        embed_dim=112,
        order=135,
        lmud_order=0.25,
        post_ffn=True,
        option_2=True,
        final_dense=False,
    ),
    dict(
        kind="lmud",
        name="row 133, 301k",
        embed_dim=112,
        order=173,
        n_layers=3,
        lmud_order=0.333,
        n_filters=1,
        post_ffn=True,
    ),
]

# Select the configuration list to report on (models50 / models300 are alternatives).
# models = models300
models = models_paper

# Maps each config's "kind" to the estimator that prints its report.
_ESTIMATORS = dict(gpt=compute_gpt, lmud=compute_lmud, lmumlp=compute_lmumlp)

for model in models:
    # Work on a copy: popping "kind" from the module-level config dicts in
    # place would make any second pass over `models` fail with KeyError.
    kwargs = dict(model)
    kind = kwargs.pop("kind")
    kwargs.setdefault("verbose", 2)
    estimator = _ESTIMATORS.get(kind)
    if estimator is None:
        raise NotImplementedError(kind)
    estimator(**kwargs)
