"""
Modified from https://github.com/google-research/electra/blob/master/flops_computation.py.
"""

import collections

# Dropout cost: draw a random number, compare it (>=), multiply the
# activation by the dropout mask, then multiply by the correction factor
# 1 / (1 - dropout_rate).
DROPOUT_FLOPS = 4

# Layer-norm cost: mean of the activations (sum), variance of the
# activations (square and sum), bias (add), and scale (multiply).
LAYER_NORM_FLOPS = 5

# GELU cost, tanh approximation:
#   0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3))))
ACTIVATION_FLOPS = 8

# Softmax cost: max and subtract (for numerical stability), exp, sum, divide.
SOFTMAX_FLOPS = 5

class TransformerHparams(object):
  """Computes the inference FLOPs for one stage of a (pooling) transformer.

  Linear layers are counted as one operation per multiply-accumulate, i.e.
  ``in_features * out_features`` per token; element-wise costs use the
  module-level *_FLOPS constants. ``get_infer_flops`` returns the stage total
  in GFLOPs.

  Args:
    h: hidden size.
    l: number of transformer blocks in the stage.
    s: number of tokens processed by each block.
    heads: number of attention heads.
    mlp_ratio: MLP hidden width as a multiple of ``h``.
    pooling: whether to add the cost of a pooling block after the stage.
    s_next: target token count used in the pooling block's ``s * s_next``
      sampling term; only read when ``pooling`` is true.
    embedding: whether to add the cost of the convolutional embedding stem.
    patch_stride: 16 selects a 32x32 final stem grid; any other value selects
      a 64x64 grid.
    hc: whether to add the cost of a 4h -> h linear layer plus a gather.
    kmeans: whether to add the k-means upsampling/distance cost.
    upsamp: whether to add bilinear upsampling of the patch and position
      embeddings.
  """

  def __init__(self, h, l, s=512, heads=12, mlp_ratio=4, pooling=False,
               s_next=512, embedding=True, patch_stride=16, hc=False,
               kmeans=False, upsamp=False):
    self.h = h  # hidden size
    self.l = l  # number of layers
    self.s = s  # sequence length (number of tokens)
    self.kqv = h  # width of each of the key/query/value projections
    self.heads = heads
    self.mlp_ratio = mlp_ratio
    self.pooling = pooling
    self.s_next = s_next
    self.embedding = embedding
    self.patch_stride = patch_stride
    self.hc = hc
    self.kmeans = kmeans
    self.upsamp = upsamp

  def _attention_flops(self):
    """Return the per-token FLOPs of the self-attention sub-block as a dict.

    Shared by get_block_flops and get_pooling_flops, which use identical
    attention accounting. Attention dropout is not counted.
    """
    return dict(
        kqv=3 * self.h * self.kqv,
        kqv_bias=3 * self.kqv,
        attention_scores=self.kqv * self.s,
        attn_softmax=SOFTMAX_FLOPS * self.s * self.heads,
        attention_scale=self.s * self.heads,
        attention_weighted_avg_values=self.h * self.s,
        attn_output=self.h * self.h,
        attn_output_bias=self.h,
        attn_output_residual=self.h,
        # NOTE(review): a per-token constant, not scaled by h as an
        # element-wise layer norm would be; kept to preserve reported numbers.
        attn_output_layer_norm=LAYER_NORM_FLOPS,
    )

  def get_block_flops(self):
    """Return the forward-pass FLOPs of one transformer block over s tokens."""
    mlp_hidden = self.h * self.mlp_ratio
    block_flops = self._attention_flops()
    block_flops.update(
        # Two linear layers: h -> mlp_hidden and mlp_hidden -> h.
        mlp=self.h * mlp_hidden + mlp_hidden * self.h,
        mlp_bias=mlp_hidden + self.h,
        # NOTE(review): the activation is charged on both MLP layer outputs
        # (mlp_hidden + h), not only the hidden layer; kept as-is.
        mlp_act=ACTIVATION_FLOPS * (mlp_hidden + self.h),
        mlp_residule=self.h,
    )
    return sum(block_flops.values()) * self.s

  def get_pooling_flops(self):
    """Return the forward-pass FLOPs of the pooling block over s tokens.

    Returns 0 when pooling is disabled. Otherwise the cost is one attention
    sub-block, an h -> h linear layer with dropout, layer norm, bias and
    residual, plus the ``fps_*`` sampling terms (an h * s distance term and an
    s * s_next iteration term per token). NOTE(review): ``fps`` presumably
    means farthest-point sampling down to s_next tokens -- confirm.
    """
    if not self.pooling:
      return 0
    pool_flops = self._attention_flops()
    pool_flops.update(
        mlp=self.h * self.h,
        mlp_dropout=DROPOUT_FLOPS * self.h,
        mlp_layer_norm=LAYER_NORM_FLOPS,
        mlp_bias=self.h,
        mlp_residual=self.h,
        fps_dist=self.h * self.s,
        fps_iteration=self.s * self.s_next,
    )
    return sum(pool_flops.values()) * self.s

  def get_embedding_flops(self):
    """Return the FLOPs of the convolutional embedding stem (0 if disabled).

    Four 3x3 convolutions take 3 -> h/8 -> h/4 -> h/2 -> h channels and are
    followed by a 1x1 h -> h convolution. Each term is
    kernel_area * in_channels * out_channels * out_height * out_width; the
    hard-coded 256/128/64 output sizes look like a 512x512 input
    (NOTE(review): confirm).
    """
    if not self.embedding:
      return 0
    p = 32 if self.patch_stride == 16 else 64  # side of the final stem grid
    # Whole channel counts: flooring each width first matches a real conv
    # stem, whereas flooring a running product does not when h % 8 != 0.
    c8, c4, c2 = self.h // 8, self.h // 4, self.h // 2
    conv_flops = dict(
        conv1=3 * 3 * 3 * c8 * 256 * 256,
        conv2=3 * 3 * c8 * c4 * 128 * 128,
        conv3=3 * 3 * c4 * c2 * 64 * 64,
        conv4=3 * 3 * c2 * self.h * p * p,
        conv5=1 * 1 * self.h * self.h * p * p,
    )
    return sum(conv_flops.values())

  def get_hc_flops(self):
    """Return the FLOPs of the hc step (0 if disabled).

    The step is a 4h -> h linear layer over s tokens plus a gather of 4h
    values per token.
    """
    if not self.hc:
      return 0
    hc_flops = dict(
        fc=(self.h * 4) * (self.h * 1) * self.s,
        gather=self.h * self.s * 4,
    )
    return sum(hc_flops.values())

  def get_kmeans_flops(self):
    """Return the FLOPs of the k-means step (0 if disabled).

    Bilinear upsampling of h channels to 512x512 is counted at 19 FLOPs per
    output value (Computational Foundations of Image Interpolation
    Algorithms). The distance term is h * 512 * 36 * 10.
    NOTE(review): the 36 and 10 factors are undocumented here -- presumably
    the cluster count and iteration count; confirm.
    """
    if not self.kmeans:
      return 0
    kmeans_flops = dict(
        upsamp=19 * self.h * 512 * 512,
        fps_dist=self.h * 512 * 36 * 10,
    )
    return sum(kmeans_flops.values())

  def get_upsamp_patch_flops(self):
    """Return the FLOPs of upsampling the patch and position embeddings.

    Each of the two h-channel maps is bilinearly upsampled to 512x512 at
    19 FLOPs per output value (Computational Foundations of Image
    Interpolation Algorithms). Returns 0 if disabled.
    """
    if not self.upsamp:
      return 0
    upsamp_flops = dict(
        emb_bilinear_upsamp=19 * self.h * 512 * 512,
        pos_bilinear_upsamp=19 * self.h * 512 * 512,
    )
    return sum(upsamp_flops.values())

  def get_binary_classification_flops(self):
    """Return the FLOPs of the output head; the head is ignored, so 0."""
    return 0

  def get_infer_flops(self):
    """Return the stage's total inference cost in GFLOPs.

    The total is l transformer blocks plus every enabled optional component
    (pooling, embedding stem, hc, k-means, embedding upsampling).
    """
    return ((self.l * self.get_block_flops()) +
            self.get_pooling_flops() +
            self.get_embedding_flops() +
            self.get_binary_classification_flops() +
            self.get_hc_flops() +
            self.get_kmeans_flops() +
            self.get_upsamp_patch_flops()) / 1e9

# Inference cost of each model in GFLOPs, keyed by the label printed by main().
# CAST-384 is costed stage by stage and the per-stage totals are summed.
MODEL_FLOPS = collections.OrderedDict([
  (
    "ViT-S/16(CONV): ",
    TransformerHparams(
        384, 11, s=1025, heads=12, mlp_ratio=4, pooling=False, s_next=None,
        embedding=True, patch_stride=16, hc=False, kmeans=True,
        upsamp=False).get_infer_flops(),
  ),
  (
    "CAST-384: ",
    sum((
        TransformerHparams(
            384, 3, s=1025, heads=12, mlp_ratio=4, pooling=True, s_next=320,
            embedding=True, patch_stride=8, hc=True, kmeans=False,
            upsamp=True).get_infer_flops(),
        TransformerHparams(
            384, 3, s=331, heads=12, mlp_ratio=4, pooling=True, s_next=160,
            embedding=False, patch_stride=False, hc=False,
            kmeans=False).get_infer_flops(),
        TransformerHparams(
            384, 3, s=161, heads=12, mlp_ratio=4, pooling=True, s_next=80,
            embedding=False, patch_stride=False, hc=False,
            kmeans=False).get_infer_flops(),
        TransformerHparams(
            384, 2, s=81, heads=12, mlp_ratio=4, pooling=True, s_next=40,
            embedding=False, patch_stride=False, hc=False,
            kmeans=False).get_infer_flops(),
    )),
  ),
])

def main():
  """Print every model label followed by its inference cost in GFLOPs."""
  for label in MODEL_FLOPS:
    print(label, MODEL_FLOPS[label])


if __name__ == "__main__":
  main()

