# https://github.com/google-research/electra/blob/master/flops_computation.py

"""Computes the flops needed for training/running transformer networks."""

import collections

# We checked this code with TensorFlow's FLOPs counting, although we had to
# correct for this issue: https://github.com/tensorflow/tensorflow/issues/22071
# Assumptions going into the FLOPs counting
#   - An "operation" is a mathematical operation, not a machine instruction. So
#     an "exp" takes one op, like an add, even though in practice an exp
#     might be slower. This is not too bad an assumption because
#     matrix-multiplies dominate the compute for most models, so minor details
#     about activation functions don't matter too much. Similarly, we count
#     matrix-multiplies as 2*m*n flops instead of m*n, as one might if
#     considering fused multiply-add ops.
#   - Backward pass takes the same number of FLOPs as forward pass. Not exactly
#     right (e.g., for softmax cross entropy loss the backward pass is faster).
#     Importantly, it really is the same for matrix-multiplies, which is most of
#     the compute anyway.
#   - We assume "dense" embedding lookups (i.e., multiplication by a one-hot
#     vector). On some hardware accelerators, these dense operations are
#     actually faster than sparse lookups.
# Please open a github issue if you spot a problem with this code!

# I am not sure if the below constants are 100% right, but they are only applied
# to O(hidden_size) activations, which is generally a lot less compute than the
# matrix-multiplies, which are O(hidden_size^2), so they don't affect the total
# number of FLOPs much.

# Each constant below is the number of operations charged per activation
# element for the named layer type.

# random number, >=, multiply activations by dropout mask, multiply activations
# by correction (1 / (1 - dropout_rate))
DROPOUT_FLOPS = 4

# compute mean activation (sum), compute variance of activation
# (square and sum), bias (add), scale (multiply)
LAYER_NORM_FLOPS = 5

# GELU: 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3))))
ACTIVATION_FLOPS = 8

# max/subtract (for stability), exp, sum, divide
SOFTMAX_FLOPS = 5


class TransformerHparams(object):
  """Computes the train/inference FLOPs for transformers.

  Args:
    h: hidden size.
    l: number of layers.
    s: sequence length.
    v: vocab size.
    e: embedding size; defaults to `h`.
    i: feed-forward intermediate size; defaults to `4 * h`.
    heads: number of attention heads; defaults to `max(h // 64, 1)`.
    head_size: size of each attention head. When given, the attention
      projection size is `head_size * heads`; otherwise it is `h`.
    output_frac: fraction of tokens using an output softmax.
    sparse_embed_lookup: if True, the input embedding lookup is not counted
      as a dense matrix multiply.
    decoder: if True, attention costs are doubled to account for the extra
      attention over encoder states.
  """

  def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None,
      head_size=None, output_frac=0.15625, sparse_embed_lookup=False,
      decoder=False):
    self.h = h  # hidden size
    self.l = l  # number of layers
    self.s = s  # sequence length
    self.v = v  # vocab size
    self.e = h if e is None else e  # embedding size
    self.i = h * 4 if i is None else i  # intermediate size
    # Resolve the head count first so that passing `head_size` without
    # `heads` uses the default head count instead of multiplying by None.
    self.heads = max(h // 64, 1) if heads is None else heads  # attention heads
    # attn proj sizes
    self.kqv = h if head_size is None else head_size * self.heads
    self.output_frac = output_frac  # percent of tokens using an output softmax
    self.sparse_embed_lookup = sparse_embed_lookup  # sparse embedding lookups
    self.decoder = decoder  # decoder has extra attn to encoder states

  def get_block_flops(self):
    """Get the forward-pass FLOPs for a single transformer block.

    Returns:
      The per-token FLOPs of the attention and feed-forward sub-layers, summed
      and multiplied by the sequence length.
    """
    # A decoder block attends to the encoder states as well, so every
    # attention-related term is counted twice.
    attn_mul = 2 if self.decoder else 1
    # NOTE(review): the value-averaging and output-projection terms below use
    # h rather than kqv; the two differ only when head_size * heads != h.
    # Confirm whether that is intended.
    block_flops = dict(
        kqv=3 * 2 * self.h * self.kqv * attn_mul,
        kqv_bias=3 * self.kqv * attn_mul,
        attention_scores=2 * self.kqv * self.s * attn_mul,
        attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul,
        attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul,
        attention_scale=self.s * self.heads * attn_mul,
        attention_weighted_avg_values=2 * self.h * self.s * attn_mul,
        attn_output=2 * self.h * self.h * attn_mul,
        attn_output_bias=self.h * attn_mul,
        attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul,
        attn_output_residual=self.h * attn_mul,
        # Layer norm touches all h activations, exactly like output_layer_norm.
        attn_output_layer_norm=LAYER_NORM_FLOPS * self.h * attn_mul,
        intermediate=2 * self.h * self.i,
        intermediate_act=ACTIVATION_FLOPS * self.i,
        intermediate_bias=self.i,
        output=2 * self.h * self.i,
        output_bias=self.h,
        output_dropout=DROPOUT_FLOPS * self.h,
        output_residual=self.h,
        output_layer_norm=LAYER_NORM_FLOPS * self.h,
    )
    return sum(block_flops.values()) * self.s

  def get_embedding_flops(self, output=False):
    """Get the forward-pass FLOPs for the transformer inputs or output softmax.

    Args:
      output: if False, count the input embedding layer; if True, count the
        output prediction head, scaled by `output_frac`.

    Returns:
      The FLOPs for the whole sequence.
    """
    embedding_flops = {}
    # The output head always multiplies by the embedding matrix; the input
    # side only does so when lookups are counted as dense.
    if output or (not self.sparse_embed_lookup):
      embedding_flops["main_multiply"] = 2 * self.e * self.v
    # input embedding post-processing
    if not output:
      embedding_flops.update(dict(
          tok_type_and_position=2 * self.e * (self.s + 2),
          add_tok_type_and_position=2 * self.e,
          emb_layer_norm=LAYER_NORM_FLOPS * self.e,
          emb_dropout=DROPOUT_FLOPS * self.e
      ))
    # projection layer if e != h
    if self.e != self.h or output:
      embedding_flops.update(dict(
          hidden_kernel=2 * self.h * self.e,
          hidden_bias=self.e if output else self.h
      ))
      # extra hidden layer and output softmax
      if output:
        embedding_flops.update(dict(
            hidden_activation=ACTIVATION_FLOPS * self.e,
            hidden_layernorm=LAYER_NORM_FLOPS * self.e,
            output_softmax=SOFTMAX_FLOPS * self.v,
            output_target_word=2 * self.v
        ))
        # Only a fraction of the positions go through the output softmax.
        return self.output_frac * sum(embedding_flops.values()) * self.s
    return sum(embedding_flops.values()) * self.s

  def get_binary_classification_flops(self):
    """Get the forward-pass FLOPs for a per-token binary classification head.

    The head is one h-by-h hidden layer with bias and activation, followed by
    a projection to a single logit.
    """
    classification_flops = dict(
        hidden=2 * self.h * self.h,
        hidden_bias=self.h,
        hidden_act=ACTIVATION_FLOPS * self.h,
        logits=2 * self.h
    )
    return sum(classification_flops.values()) * self.s

  def get_train_flops(self, batch_size, train_steps, discriminator=False):
    """Get the FLOPs for pre-training the transformer.

    Args:
      batch_size: number of sequences per training step.
      train_steps: number of training steps.
      discriminator: if True, use the binary classification head instead of
        the output softmax head.

    Returns:
      The total forward and backward FLOPs over all training steps.
    """
    if discriminator:
      head_flops = self.get_binary_classification_flops()
    else:
      head_flops = self.get_embedding_flops(output=True)
    # 2* for forward/backward pass
    return 2 * batch_size * train_steps * (
        (self.l * self.get_block_flops()) +
        self.get_embedding_flops(output=False) +
        head_flops
    )

  def get_infer_flops(self):
    """Get the FLOPs for running inference with the transformer on a
    classification task."""
    return ((self.l * self.get_block_flops()) +
            self.get_embedding_flops(output=False) +
            self.get_binary_classification_flops())


def get_electra_train_flops(
    h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings,
    e=None, s=512, output_frac=0.15625):
  """Get the FLOPs needed for pre-training ELECTRA.

  The total is the training cost of the discriminator (hidden size `h_d`,
  `l_d` layers) plus that of the generator (hidden size `h_g`, `l_g` layers).

  Args:
    h_d: discriminator hidden size.
    l_d: discriminator number of layers.
    h_g: generator hidden size.
    l_g: generator number of layers.
    batch_size: number of sequences per training step.
    train_steps: number of training steps.
    tied_embeddings: if True, the generator uses the same embedding size as
      the discriminator; otherwise it uses its own hidden size.
    e: embedding size; defaults to `h_d`.
    s: sequence length.
    output_frac: fraction of tokens using an output softmax.

  Returns:
    The combined discriminator and generator training FLOPs.
  """
  embedding_size = h_d if e is None else e
  # Passing None lets the generator fall back to its own hidden size.
  generator_embedding_size = embedding_size if tied_embeddings else None

  discriminator = TransformerHparams(
      h_d, l_d, s=s, e=embedding_size, output_frac=output_frac)
  generator = TransformerHparams(
      h_g, l_g, s=s, e=generator_embedding_size, output_frac=output_frac)

  discriminator_flops = discriminator.get_train_flops(
      batch_size, train_steps, True)
  generator_flops = generator.get_train_flops(batch_size, train_steps)
  return discriminator_flops + generator_flops


# FLOPs estimates keyed by model name, kept in insertion order. Most entries
# are total pre-training FLOPs; "bert_base_infer" is a single inference pass.
MODEL_FLOPS = collections.OrderedDict([
    # These runtimes were computed with TensorFlow FLOPs counting instead of
    # the script, as the neural architectures are quite different.
    # 768648884 words in LM1b benchmark, 10 epochs with batch size 20,
    # seq length 128, 568093262680 FLOPs per example.
    ("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)),
    # 15064773691518 is FLOPs for forward pass on 32 examples.
    # Therefore 2 * steps * batch_size * 15064773691518 / 32 is XLNet compute
    ("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0),

    # Runtimes computed with the script
    ("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops(
        128, 960800)),
    ("bert_small", TransformerHparams(256, 12, e=128, s=128).get_train_flops(128, 1.45e6)),
    ("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)),
    ("bert_base_infer", TransformerHparams(768, 12).get_infer_flops()),
    ("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)),
    ("electra_small", get_electra_train_flops(256, 12, 64, 12, 128, 1e6, True, s=128, e=128)),
    ("electra_base", get_electra_train_flops(768, 12, 256, 12, 256, 766000, True)),
    ("electra_400k", get_electra_train_flops(1024, 24, 256, 24, 2048, 400000, True)),
    ("electra_1.75M", get_electra_train_flops(1024, 24, 256, 24, 2048, 1750000, True)),

    # RoBERTa, ALBERT, and T5 have minor architectural differences from
    # BERT/ELECTRA, but I believe they don't significantly affect the runtime,
    # so we use this script for those models as well.
    ("roberta", TransformerHparams(1024, 24, v=50265).get_train_flops(8000, 500000)),
    ("albert", TransformerHparams(4096, 12, v=30000, e=128).get_train_flops(
        4096, 1.5e6)),
    # T5 is counted as an encoder stack plus a decoder stack.
    ("t5_11b", TransformerHparams(
        1024,  # hidden size
        24,  # layers
        v=32000,  # vocab size
        i=65536,  # ff intermediate hidden size
        heads=128, head_size=128,  # heads/head size
        output_frac=0.0  # encoder has no output softmax
    ).get_train_flops(2048, 1e6) +  # 1M steps with batch size 2048
     TransformerHparams(
         1024,
         24,
         v=32000,
         i=65536,
         heads=128, head_size=128,
         output_frac=1.0,  # decoder has output softmax for all positions
         decoder=True
     ).get_train_flops(2048, 1e6))
])


def main():
  """Print every MODEL_FLOPS entry as "<name> <flops>", one per line."""
  for model_name in MODEL_FLOPS:
    print(model_name, MODEL_FLOPS[model_name])


# Print the FLOPs table only when the file is executed as a script.
if __name__ == "__main__":
  main()
