import argparse
import logging
import torch
import torch.nn as nn
import timeit

from maskrcnn_benchmark.layers import *
from maskrcnn_benchmark.modeling.backbone.resnet_big import StdConv2d
from maskrcnn_benchmark.modeling.backbone.fpn import *
from maskrcnn_benchmark.modeling.rpn.inference import *
from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import PostProcessor
from maskrcnn_benchmark.modeling.rpn.anchor_generator import BufferList


def profile(model, input_size, custom_ops=None, device="cpu", verbose=False,
            extra_args=None, return_time=False):
    """Count operations and parameters of ``model`` using one forward pass.

    A forward hook is attached to every leaf module (a module without
    children) whose type has a counting function, the model is run once in
    eval mode under ``torch.no_grad()`` on an all-zero input, and the
    per-module counts are summed.

    Side effect: every leaf module gets ``total_ops`` and ``total_params``
    buffers registered on it; they are left in place after this call.

    Args:
        model: module to profile.
        input_size: shape handed to ``torch.zeros`` to build the dummy input.
        custom_ops: optional mapping from module type to forward hook; it is
            consulted before the module-level ``register_hooks`` table.
        device: device the model and the dummy input are moved to.
        verbose: when true, print each module that gets a counting hook.
        extra_args: optional keyword arguments forwarded to ``model(...)``.
        return_time: when true, also return the forward-pass wall time.

    Returns:
        ``(total_ops, total_params)`` as floats, or
        ``(total_ops, total_params, total_time)`` when ``return_time`` is set.
    """
    # Fresh dicts per call instead of shared mutable defaults.
    custom_ops = {} if custom_ops is None else custom_ops
    extra_args = {} if extra_args is None else extra_args
    handler_collection = []

    def add_hooks(m):
        # Only leaf modules are counted; containers would double count.
        if len(list(m.children())) > 0:
            return

        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))

        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        m_type = type(m)
        fn = None

        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            # May be None: the type is known but no hook is attached.
            fn = register_hooks[m_type]
        else:
            print("Not implemented for ", m)

        if fn is not None:
            if verbose:
                print("Register FLOP counter for module %s" % str(m))
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)

    # Remember where the model lives. A model without parameters falls back
    # to its buffers; with neither there is nothing to move back afterwards.
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    original_device = None if reference is None else reference.device
    training = model.training

    try:
        model.eval().to(device)
        model.apply(add_hooks)

        x = torch.zeros(input_size).to(device)
        with torch.no_grad():
            tic = timeit.default_timer()
            model(x, **extra_args)
            toc = timeit.default_timer()
            total_time = toc - tic

        # Sum as Python floats (double precision): float32 tensors cannot
        # represent large counts exactly.
        total_ops = 0.0
        total_params = 0.0
        for m in model.modules():
            if len(list(m.children())) > 0:  # skip for non-leaf module
                continue
            total_ops += float(m.total_ops)
            total_params += float(m.total_params)
    finally:
        # Always detach the hooks and restore mode/device, even when the
        # forward pass raised.
        for handler in handler_collection:
            handler.remove()
        model.train(training)
        if original_device is not None:
            model.to(original_device)

    if return_time:
        return total_ops, total_params, total_time
    return total_ops, total_params


# Weight applied to each multiply-accumulate in the convolution count.
multiply_adds = 1


def count_conv2d(m, x, y):
    """Forward hook: store the operation count of a 2-D convolution on ``m``.

    Args:
        m: convolution module; ``in_channels``, ``out_channels``,
            ``kernel_size``, ``groups`` and ``bias`` are read from it.
        x: tuple of forward inputs; dim 0 of the first entry is the batch size.
        y: forward output; its sizes at dims 2 and 3 are used as height/width.
    """
    inp = x[0]
    kernel_h, kernel_w = m.kernel_size
    n_batch = inp.size()[0]

    # Cost of a single output value: one multiply-add per kernel weight that
    # feeds it, plus one addition when the layer has a bias.
    per_output = multiply_adds * kernel_h * kernel_w * m.in_channels // m.groups
    if m.bias is not None:
        per_output += 1

    n_outputs = n_batch * y.size(3) * y.size(2) * m.out_channels
    m.total_ops = torch.Tensor([int(n_outputs * per_output)])


def count_convtranspose2d(m, x, y):
    """Forward hook: store the operation count of a transposed convolution.

    The count is the number of output elements multiplied by the number of
    elements in ``m.weight``.

    Args:
        m: module with a ``weight`` tensor.
        x: tuple of forward inputs (not used by this count).
        y: forward output tensor.
    """
    # The previous per-element kernel/bias computation was dead code: its
    # result was overwritten before use, so it has been removed.
    # NOTE(review): charging the whole weight tensor to every output element
    # looks like an overcount compared with count_conv2d; confirm whether
    # this formula is intended before relying on the numbers.
    ops_per_element = m.weight.nelement()
    output_elements = y.nelement()
    total_ops = output_elements * ops_per_element
    m.total_ops = torch.Tensor([int(total_ops)])


def count_bn(m, x, y):
    """Forward hook: store a normalization layer's operation count on ``m``.

    Each element of the first input is charged four operations
    (subtract, divide, gamma, beta).
    """
    ops_per_value = 4
    m.total_ops = torch.Tensor([int(ops_per_value * x[0].numel())])


def count_relu(m, x, y):
    """Forward hook: charge one operation per element of the first input."""
    first_input = x[0]
    m.total_ops = torch.Tensor([int(first_input.numel())])


def count_softmax(m, x, y):
    """Forward hook: store the operation count of a softmax on ``m``.

    For each slice along the softmax dimension with ``n`` entries the count
    is ``n`` exponentials, ``n - 1`` additions and ``n`` divisions.

    Works for inputs of any rank: the softmax dimension is read from
    ``m.dim`` and all remaining dimensions are treated as the batch. For a
    2-D input normalized over dim 1 this equals the former 2-D-only result.

    Args:
        m: softmax module; its ``dim`` attribute is used when present.
        x: tuple of forward inputs; only the first entry is inspected.
        y: forward output (unused).
    """
    x = x[0]
    dim = getattr(m, "dim", None)
    if dim is None:
        # No explicit dim: use the dimension torch itself picks for the
        # legacy implicit-dim softmax (0 for 0/1/3-D inputs, otherwise 1).
        dim = 0 if x.dim() in (0, 1, 3) else 1
    nfeatures = x.size(dim)
    # Every slice along `dim` is one softmax; guard the empty-dimension case.
    batch_size = x.numel() // nfeatures if nfeatures else 0
    total_exp = nfeatures
    total_add = nfeatures - 1
    total_div = nfeatures
    total_ops = batch_size * (total_exp + total_add + total_div)
    m.total_ops = torch.Tensor([int(total_ops)])


def count_maxpool(m, x, y):
    """Forward hook: store a max-pooling layer's operation count on ``m``.

    Every output element is charged the product of the values in
    ``m.kernel_size``.
    """
    # NOTE(review): a scalar kernel_size contributes only its own value here,
    # not value ** number_of_spatial_dims; confirm this is intended.
    window = torch.Tensor([m.kernel_size]).prod()
    m.total_ops = torch.Tensor([int(window * y.numel())])


def count_adap_maxpool(m, x, y):
    """Forward hook: store an adaptive max-pooling layer's operation count.

    The pooling window is estimated as the first input's sizes from dim 2
    onward floor-divided by ``m.output_size``; each output element is charged
    the product of that window.
    """
    in_extent = torch.Tensor(list(x[0].shape[2:]))
    out_extent = torch.Tensor([m.output_size]).squeeze()
    window = torch.prod(in_extent // out_extent)
    m.total_ops = torch.Tensor([int(window * y.numel())])


def count_avgpool(m, x, y):
    """Forward hook: store an average-pooling layer's operation count on ``m``.

    Every output element is charged the product of the values in
    ``m.kernel_size`` (additions) plus one division.
    """
    additions = torch.Tensor([m.kernel_size]).prod()
    per_output = additions + 1  # +1 for the division by the window size
    m.total_ops = torch.Tensor([int(per_output * y.numel())])


def count_adap_avgpool(m, x, y):
    """Forward hook: store an adaptive average-pooling layer's operation count.

    The pooling window is estimated as the first input's sizes from dim 2
    onward floor-divided by ``m.output_size``; each output element is charged
    the product of that window (additions) plus one division.
    """
    in_extent = torch.Tensor(list(x[0].shape[2:]))
    out_extent = torch.Tensor([m.output_size]).squeeze()
    additions = torch.prod(in_extent // out_extent)
    per_output = additions + 1  # +1 for the division by the window size
    m.total_ops = torch.Tensor([int(per_output * y.numel())])


def count_linear(m, x, y):
    """Forward hook: store a linear layer's operation count on ``m``.

    Each output element is charged ``in_features`` multiplications and
    ``in_features - 1`` additions.
    """
    fan_in = m.in_features
    per_output = fan_in + (fan_in - 1)
    m.total_ops = torch.Tensor([int(per_output * y.numel())])


def count_LastLevelMaxPool(m, x, y):
    """Forward hook: charge one operation per element of the last output.

    ``y`` is indexed with ``-1``, so it is expected to be a sequence of
    tensors; only its final entry is counted.
    """
    last_output = y[-1]
    m.total_ops = torch.Tensor([int(last_output.numel())])


def count_ROIAlign(m, x, y):
    """Forward hook: charge four operations per element of the output ``y``."""
    ops_per_output = 4
    m.total_ops = torch.Tensor([int(y.numel() * ops_per_output)])


# Lookup table from leaf-module type to the forward hook that counts its
# operations. A value of None means the type is known but gets no hook
# (and no "Not implemented" message from profile()).
register_hooks = {}
for _hook, _module_types in (
    (None, (Scale,)),
    (count_conv2d, (Conv2d, nn.Conv2d, ModulatedDeformConv, StdConv2d)),
    (count_bn, (
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
        FrozenBatchNorm2d, nn.GroupNorm, NaiveSyncBatchNorm2d,
    )),
    (count_relu, (nn.ReLU, nn.ReLU6)),
    (None, (swish, nn.ConstantPad2d)),
    (count_LastLevelMaxPool, (SPPLayer, LastLevelMaxPool)),
    (count_maxpool, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)),
    (count_adap_maxpool, (
        nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d,
    )),
    (count_avgpool, (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)),
    (count_adap_avgpool, (
        nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d,
    )),
    (count_linear, (nn.Linear,)),
    (None, (nn.Upsample, nn.Dropout, nn.Sigmoid, DropBlock2D)),
    (count_ROIAlign, (ROIAlign,)),
    (None, (
        RPNPostProcessor, PostProcessor, BufferList,
        RetinaPostProcessor, FCOSPostProcessor, ATSSPostProcessor,
    )),
):
    # Rows are applied in order, so keys keep the original insertion order.
    register_hooks.update(dict.fromkeys(_module_types, _hook))
del _hook, _module_types