"""
Modified from https://github.com/google-research/electra/blob/master/flops_computation.py.
"""

import collections

# Operations counted for dropout: draw a random number, threshold it (>=),
# multiply activations by the dropout mask, multiply activations by the
# correction (1 / (1 - dropout_rate)).
DROPOUT_FLOPS = 4

# Operations counted for layer norm: compute mean activation (sum), compute
# variance of activation (square and sum), bias (add), scale (multiply).
LAYER_NORM_FLOPS = 5

# Operations counted for the GELU activation:
# 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3))))
ACTIVATION_FLOPS = 8

# Operations counted for softmax: max and subtract (for numerical stability),
# exp, sum, divide.
SOFTMAX_FLOPS = 5

class TransformerHparams(object):
  """Analytic forward-pass FLOP estimate for one transformer stage.

  A stage is `l` identical transformer blocks operating on `s` tokens of
  width `h`, optionally followed by one pooling block and optionally
  preceded by an input embedding. `get_infer_flops` combines the parts and
  reports the total divided by 1e9.

  Args:
    h: hidden size (also used as the q/k/v projection size).
    l: number of transformer blocks in the stage.
    s: sequence length (number of tokens).
    heads: number of attention heads.
    mlp_ratio: MLP hidden width as a multiple of `h` (transformer blocks only).
    pooling: whether to add the cost of one pooling block.
    s_next: multiplier of the `fps_iteration` pooling term (presumably the
      token count after pooling -- confirm against the model). `0` or `None`
      means the stage has no pooling cost.
    embedding: selects the embedding term: truthy -> convolutional stem,
      falsy -> single patch-projection term.
    patch_stride: `16` selects a 14x14 final grid for the convolutional stem,
      any other value selects 28x28.
  """

  def __init__(self, h, l, s=512, heads=12, mlp_ratio=4, pooling=False, s_next=512, embedding=True, patch_stride=16):
    self.h = h  # hidden size
    self.l = l  # number of layers
    self.s = s  # sequence length
    self.kqv = h  # attn proj sizes
    self.heads = heads
    self.mlp_ratio = mlp_ratio
    self.pooling = pooling
    self.s_next = s_next
    self.embedding = embedding
    self.patch_stride = patch_stride

  def _attention_flops(self):
    """Return the per-token attention terms shared by block and pooling costs."""
    return dict(
        kqv=3 * self.h * self.kqv,
        kqv_bias=3 * self.kqv,
        attention_scores=self.kqv * self.s,
        attn_softmax=SOFTMAX_FLOPS * self.s * self.heads,
        attention_scale=self.s * self.heads,
        attention_weighted_avg_values=self.h * self.s,
        attn_output=self.h * self.h,
        attn_output_bias=self.h,
        attn_output_residual=self.h,
        attn_output_layer_norm=LAYER_NORM_FLOPS,
    )

  def get_block_flops(self):
    """Get the forward-pass FLOPs for a single transformer block.

    Returns:
      The sum of the per-token attention and MLP terms multiplied by the
      sequence length `s`. Dropout terms are not counted.
    """
    mlp_width = self.h * self.mlp_ratio
    block_flops = self._attention_flops()
    block_flops.update(
        # Two projections: h -> h * mlp_ratio and h * mlp_ratio -> h.
        mlp=self.h * self.h * self.mlp_ratio + self.h * self.mlp_ratio * self.h,
        mlp_bias=mlp_width + self.h,
        # The activation cost is charged on both the hidden and output widths.
        mlp_act=ACTIVATION_FLOPS * (mlp_width + self.h),
        mlp_residule=self.h,
    )
    return sum(block_flops.values()) * self.s

  def get_pooling_flops(self):
    """Get the forward-pass FLOPs for the stage's single pooling block.

    Returns:
      0 when pooling is disabled or there is no next stage (`s_next` is `0`
      or `None`); otherwise the per-token attention, MLP, dropout, layer-norm
      and `fps_*` terms summed and multiplied by the sequence length `s`.
    """
    # `not self.s_next` also covers None, which callers pass for "no next
    # stage"; multiplying by None below would otherwise raise TypeError.
    if not self.pooling or not self.s_next:
      return 0
    block_flops = self._attention_flops()
    block_flops.update(
        # NOTE(review): the MLP ratio is fixed at 4 here and ignores
        # self.mlp_ratio -- confirm this matches the pooling module.
        mlp=self.h * self.h * 4 + self.h * 4 * self.h,
        mlp_dropout=DROPOUT_FLOPS * self.h * 4 + DROPOUT_FLOPS * self.h,
        mlp_layer_norm=LAYER_NORM_FLOPS + LAYER_NORM_FLOPS * 4,
        mlp_bias=self.h + self.h * 4,
        mlp_residual=self.h,
        # "fps" presumably stands for farthest point sampling -- confirm.
        fps_dist=self.h * self.s,
        fps_iteration=self.s * self.s_next,
    )
    return sum(block_flops.values()) * self.s

  def get_embedding_flops(self):
    """Get the forward-pass FLOPs for the input embedding.

    Returns:
      With `embedding` falsy: one projection term using a fixed 16x16 patch,
      3 input channels and a 14x14 grid. With `embedding` truthy: the sum of
      five convolution terms on fixed 112/56/28 grids, ending on a 14x14 grid
      (`patch_stride == 16`) or a 28x28 grid (otherwise).
    """
    if not self.embedding:
      k = 16
      p = 14
      return 3 * self.h * k * k * p * p

    p = 14 if self.patch_stride == 16 else 28
    # NOTE: `//` binds left-to-right with `*`, so it floors the running
    # product rather than a channel width. The two readings agree whenever
    # h is a multiple of 8.
    conv_flops = dict(
        conv1=3 * 3 * 3 * self.h // 8 * 112 * 112,
        conv2=3 * 3 * self.h // 8 * self.h // 4 * 56 * 56,
        conv3=3 * 3 * self.h // 4 * self.h // 2 * 28 * 28,
        conv4=3 * 3 * self.h // 2 * self.h // 1 * p * p,
        conv5=1 * 1 * self.h * self.h * p * p,
    )
    return sum(conv_flops.values())

  def get_binary_classification_flops(self):
    """Get the FLOPs of the output head, which this estimate ignores (0)."""
    return 0

  def get_infer_flops(self):
    """Get the FLOPs for running inference with the transformer on a
    classification task.

    Returns:
      `l` transformer blocks plus the pooling, embedding and output-head
      terms, divided by 1e9 (a float).
    """
    total = (self.l * self.get_block_flops()
             + self.get_pooling_flops()
             + self.get_embedding_flops()
             + self.get_binary_classification_flops())
    return total / 1e9

# Estimated forward-pass cost (value of `get_infer_flops`) per reference model,
# keyed by the label printed in front of it. Multi-stage CAST models are the
# sum of their four stages; only the first stage uses the convolutional
# embedding, and the last stage has no next stage (s_next=0).
MODEL_FLOPS = collections.OrderedDict()

MODEL_FLOPS["ViT-B/16(MLP): "] = TransformerHparams(
    h=768, l=12, s=197, heads=12, mlp_ratio=4,
    pooling=False, s_next=None, embedding=False).get_infer_flops()

MODEL_FLOPS["ViT-B/16(CONV): "] = TransformerHparams(
    h=768, l=11, s=197, heads=12, mlp_ratio=4,
    pooling=False, s_next=None, embedding=True, patch_stride=16).get_infer_flops()

MODEL_FLOPS["ViT-S/16(CONV): "] = TransformerHparams(
    h=384, l=11, s=197, heads=12, mlp_ratio=4,
    pooling=False, s_next=None, embedding=True, patch_stride=16).get_infer_flops()

MODEL_FLOPS["CAST-S: "] = sum([
    TransformerHparams(h=384, l=3, s=197, heads=12, mlp_ratio=4,
                       pooling=True, s_next=64, embedding=True, patch_stride=8).get_infer_flops(),
    TransformerHparams(h=384, l=3, s=65, heads=12, mlp_ratio=4,
                       pooling=True, s_next=32, embedding=False).get_infer_flops(),
    TransformerHparams(h=384, l=3, s=33, heads=12, mlp_ratio=4,
                       pooling=True, s_next=16, embedding=False).get_infer_flops(),
    TransformerHparams(h=384, l=2, s=17, heads=12, mlp_ratio=4,
                       pooling=True, s_next=0, embedding=False).get_infer_flops(),
])

MODEL_FLOPS["CAST-SD: "] = sum([
    TransformerHparams(h=384, l=6, s=197, heads=12, mlp_ratio=4,
                       pooling=True, s_next=64, embedding=True, patch_stride=8).get_infer_flops(),
    TransformerHparams(h=384, l=3, s=65, heads=12, mlp_ratio=4,
                       pooling=True, s_next=32, embedding=False).get_infer_flops(),
    TransformerHparams(h=384, l=3, s=33, heads=12, mlp_ratio=4,
                       pooling=True, s_next=16, embedding=False).get_infer_flops(),
    TransformerHparams(h=384, l=3, s=17, heads=12, mlp_ratio=4,
                       pooling=True, s_next=0, embedding=False).get_infer_flops(),
])

MODEL_FLOPS["CAST-B: "] = sum([
    TransformerHparams(h=768, l=3, s=197, heads=12, mlp_ratio=4,
                       pooling=True, s_next=64, embedding=True, patch_stride=8).get_infer_flops(),
    TransformerHparams(h=768, l=3, s=65, heads=12, mlp_ratio=4,
                       pooling=True, s_next=32, embedding=False).get_infer_flops(),
    TransformerHparams(h=768, l=3, s=33, heads=12, mlp_ratio=4,
                       pooling=True, s_next=16, embedding=False).get_infer_flops(),
    TransformerHparams(h=768, l=2, s=17, heads=12, mlp_ratio=4,
                       pooling=True, s_next=0, embedding=False).get_infer_flops(),
])

MODEL_FLOPS["CAST-BD: "] = sum([
    TransformerHparams(h=768, l=6, s=197, heads=12, mlp_ratio=4,
                       pooling=True, s_next=64, embedding=True, patch_stride=8).get_infer_flops(),
    TransformerHparams(h=768, l=3, s=65, heads=12, mlp_ratio=4,
                       pooling=True, s_next=32, embedding=False).get_infer_flops(),
    TransformerHparams(h=768, l=3, s=33, heads=12, mlp_ratio=4,
                       pooling=True, s_next=16, embedding=False).get_infer_flops(),
    TransformerHparams(h=768, l=3, s=17, heads=12, mlp_ratio=4,
                       pooling=True, s_next=0, embedding=False).get_infer_flops(),
])

def main():
  """Print every entry of MODEL_FLOPS as `<label> <value>`, one per line."""
  for label in MODEL_FLOPS:
    print(label, MODEL_FLOPS[label])

if __name__ == "__main__":
  main()
