import numpy as np

import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import pdb
import time

# Module-level accumulators. Each forward hook registered by this module
# appends one operation count per forward call of the layer it is attached to.
list_conv = []
list_bn = []
list_relu = []
list_linear = []
list_pooling = []
list_upsample = []


def reset_lists():
    """Empty every module-level FLOP accumulator.

    The lists are cleared in place rather than rebound. Plain assignment
    inside this function would only create local names and leave the
    module-level lists untouched; clearing in place also keeps references
    held elsewhere (e.g. ``from <module> import list_conv``) pointing at the
    live accumulator.
    """
    accumulators = (
        list_conv,
        list_bn,
        list_relu,
        list_linear,
        list_pooling,
        list_upsample,
    )
    for accumulator in accumulators:
        accumulator.clear()
 
def register_handles(model, multiply_adds=True):
    """Attach FLOP-counting forward hooks to the leaf modules of ``model``.

    Calls ``reset_lists()`` and then registers a forward hook on every module
    that has no children and is one of: Conv2d / ConvTranspose2d, Linear,
    BatchNorm2d, ReLU / LeakyReLU, MaxPool2d / AvgPool2d, Upsample. Each hook
    appends the estimated operation count of one forward call to the matching
    module-level list (``list_conv``, ``list_bn``, ``list_relu``,
    ``list_linear``, ``list_pooling``, ``list_upsample``).

    Args:
        model: The ``torch.nn.Module`` to instrument.
        multiply_adds: If true, a multiply-accumulate counts as two
            operations in convolution and linear layers; otherwise as one.

    Returns:
        The list of handles returned by ``register_forward_hook``, in
        registration order, so the caller can detach the hooks again with
        ``handle.remove()``.
    """
    reset_lists()
    ops_per_mac = 2 if multiply_adds else 1

    def conv_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        # output[0] is the first sample of the batch: (channels, height, width).
        out_channels, out_height, out_width = output[0].size()

        # PyTorch requires in_channels to be divisible by groups, so floor
        # division is exact and keeps the count an integer.
        kernel_ops = (
            module.kernel_size[0]
            * module.kernel_size[1]
            * (module.in_channels // module.groups)
        )
        bias_ops = 1 if module.bias is not None else 0

        # NOTE(review): the same per-output formula is applied to
        # ConvTranspose2d; confirm that is the intended estimate.
        ops_per_output = kernel_ops * ops_per_mac + bias_ops
        list_conv.append(
            ops_per_output * out_channels * out_height * out_width * batch_size
        )

    def bn_hook(module, inputs, output):
        # Two operations per input element.
        list_bn.append(inputs[0].nelement() * 2)

    def relu_hook(module, inputs, output):
        # One operation per input element.
        list_relu.append(inputs[0].nelement())

    def linear_hook(module, inputs, output):
        # The layer acts on the last dimension, so the number of rows it
        # processes is the product of all leading dimensions (1 for a 1-D
        # input, the batch size for a 2-D input).
        rows = 1
        for extent in inputs[0].size()[:-1]:
            rows *= extent

        weight_ops = module.weight.nelement() * ops_per_mac
        bias_ops = module.bias.nelement() if module.bias is not None else 0

        list_linear.append(rows * (weight_ops + bias_ops))

    def pooling_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        out_channels, out_height, out_width = output[0].size()

        # kernel_size may be a single int or a tuple; normalise to a tuple and
        # multiply first by last so a one-element tuple also works.
        kernel = module.kernel_size
        if isinstance(kernel, int):
            kernel = (kernel, kernel)
        kernel_ops = kernel[0] * kernel[-1]

        list_pooling.append(
            kernel_ops * out_channels * out_height * out_width * batch_size
        )

    # For bilinear upsample
    def upsample_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        out_channels, out_height, out_width = output[0].size()

        # 12 operations are charged per output element.
        list_upsample.append(
            out_height * out_width * out_channels * batch_size * 12
        )

    hook_table = (
        ((torch.nn.Conv2d, torch.nn.ConvTranspose2d), conv_hook),
        ((torch.nn.Linear,), linear_hook),
        ((torch.nn.BatchNorm2d,), bn_hook),
        ((torch.nn.ReLU, torch.nn.LeakyReLU), relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
        ((torch.nn.Upsample,), upsample_hook),
    )

    handles = []
    # modules() yields each module instance once (including ``model`` itself),
    # so a layer shared between several parents is hooked a single time.
    for module in model.modules():
        if list(module.children()):
            continue  # only leaf modules are instrumented
        for layer_types, hook in hook_table:
            if isinstance(module, layer_types):
                handles.append(module.register_forward_hook(hook))

    return handles

def calculate_flops():
    """Return the sum of all recorded operation counts, divided by 1e9."""
    accumulators = (
        list_conv,
        list_bn,
        list_relu,
        list_linear,
        list_pooling,
        list_upsample,
    )
    grand_total = sum(sum(counts) for counts in accumulators)
    return grand_total / 1e9
