"""
Adapted in from the original NASBench-201 code repository at 
https://github.com/D-X-Y/AutoDL-Projects/tree/bc4c4692589e8ee7d6bab02603e69f8e5bd05edc
"""

import copy, torch
import torch.nn as nn
import numpy as np


def get_model_flops(model, input_shape, transfer_devices=False, device="cpu"):
    """ For a given model and input shape, estimates the number of floating point operations that will be performed
    per second of computation, counted in millions, i.e. a value of 1. indicates approximately 1 million floating
    point operations per second. Note that calls to this function do advance the state of the PyTorch PRNG. """
    model = add_flops_counting_methods(model)
    model.eval()

    cache_inputs = torch.rand(*input_shape)
    if transfer_devices: 
        cache_inputs = cache_inputs.to(device)
    with torch.no_grad():
        _____ = model(cache_inputs)
    FLOPs = compute_average_flops_cost(model) / 1e6

    torch.cuda.empty_cache()
    model.apply(_remove_hook_function)
    return FLOPs


# ---- Public functions
def add_flops_counting_methods(model):
    model.__batch_counter__ = 0
    _add_batch_counter_hook_function(model)
    model.apply(_add_flops_counter_variable_or_reset)
    model.apply(_add_flops_counter_hook_function)
    return model


def compute_average_flops_cost(model):
    """
    A method that will be available after add_flops_counting_methods() is called on a desired net object.
    Returns current mean flops consumption per image.
    """
    batches_count = model.__batch_counter__
    flops_sum = 0
    # or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
                or isinstance(module, torch.nn.Conv1d) \
                or hasattr(module, 'calculate_flop_self'):
            flops_sum += module.__flops__
    return flops_sum / batches_count


# ---- Internal functions
def _pool_flops_counter_hook(pool_module, inputs, output):
    batch_size = inputs[0].size(0)
    kernel_size = pool_module.kernel_size
    out_C, output_height, output_width = output.shape[1:]
    assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())

    overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
    pool_module.__flops__ += overall_flops


def _self_calculate_flops_counter_hook(self_module, inputs, output):
    overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
    self_module.__flops__ += overall_flops


def _fc_flops_counter_hook(fc_module, inputs, output):
    batch_size = inputs[0].size(0)
    xin, xout = fc_module.in_features, fc_module.out_features
    assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
    overall_flops = batch_size * xin * xout
    if fc_module.bias is not None:
        overall_flops += batch_size * xout
    fc_module.__flops__ += overall_flops


def _conv1d_flops_counter_hook(conv_module, inputs, outputs):
    batch_size = inputs[0].size(0)
    outL = outputs.shape[-1]
    [kernel] = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = kernel * in_channels * out_channels / groups

    active_elements_count = batch_size * outL
    overall_flops = conv_per_position_flops * active_elements_count

    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops


def _conv2d_flops_counter_hook(conv_module, inputs, output):
    batch_size = inputs[0].size(0)
    output_height, output_width = output.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups

    active_elements_count = batch_size * output_height * output_width
    overall_flops = conv_per_position_flops * active_elements_count

    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops


def _batch_counter_hook(module, inputs, output):
    # Can have multiple inputs, getting the first one
    inputs = inputs[0]
    batch_size = inputs.shape[0]
    module.__batch_counter__ += batch_size


def _add_batch_counter_hook_function(module):
    if not hasattr(module, '__batch_counter_handle__'):
        handle = module.register_forward_hook(_batch_counter_hook)
        module.__batch_counter_handle__ = handle


def _add_flops_counter_variable_or_reset(module):
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
            or isinstance(module, torch.nn.Conv1d) \
            or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
            or hasattr(module, 'calculate_flop_self'):
        module.__flops__ = 0


def _add_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d):
        if not hasattr(module, '__flops_handle__'):
            handle = module.register_forward_hook(_conv2d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Conv1d):
        if not hasattr(module, '__flops_handle__'):
            handle = module.register_forward_hook(_conv1d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Linear):
        if not hasattr(module, '__flops_handle__'):
            handle = module.register_forward_hook(_fc_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
        if not hasattr(module, '__flops_handle__'):
            handle = module.register_forward_hook(_pool_flops_counter_hook)
            module.__flops_handle__ = handle
    elif hasattr(module, 'calculate_flop_self'):  # self-defined module
        if not hasattr(module, '__flops_handle__'):
            handle = module.register_forward_hook(_self_calculate_flops_counter_hook)
            module.__flops_handle__ = handle


def _remove_hook_function(module):
    hookers = ['__batch_counter_handle__', '__flops_handle__']
    for hooker in hookers:
        if hasattr(module, hooker):
            handle = getattr(module, hooker)
            handle.remove()
    keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers
    for ckey in keys:
        if hasattr(module, ckey): delattr(module, ckey)
