from __future__ import division

from functools import reduce
import operator

import torch

# Running totals: reset to zero by measure_model and accumulated by _measure_layer.
count_ops, count_params = 0, 0


def length_gen(gen):
    """Consume the iterable ``gen`` and return how many items it yielded."""
    total = 0
    for _ in gen:
        total += 1
    return total


def _is_leaf(model):
    """Return True when ``model`` has no direct child modules.

    Stops at the first child rather than counting every child just to
    compare the count with zero.
    """
    return not any(True for _ in model.children())


def measure_model(model, H, W, encoder_feature_dim, action_shape, flag_only=False, C=3, is_critic=False,
                  is_actor=False, device=None):
    """Count the ops and parameters of ``model`` by running one dummy forward pass.

    Every leaf sub-module's ``forward`` is temporarily wrapped so that, right before
    it runs, ``_measure_layer`` adds its cost to the module-level ``count_ops`` /
    ``count_params`` totals. The wrappers and each module's train/eval flag are
    restored afterwards, even if the forward pass raises.

    Args:
        model: the ``torch.nn.Module`` to measure.
        H, W: spatial size of the dummy image input, used when neither
            ``is_critic`` nor ``is_actor`` is set.
        encoder_feature_dim: width of the dummy feature vector fed to
            actor/critic heads.
        action_shape: sequence whose first element is the dummy action width
            (critic only).
        flag_only: passed to ``_measure_layer``; when true, only layers with a
            truthy ``measure`` attribute are counted.
        C: channel count of the dummy image input.
        is_critic: call ``model.forward(features, action)`` with zero tensors of
            shape ``(1, encoder_feature_dim)`` and ``(1, action_shape[0])``.
        is_actor: call ``model.forward(features)`` with a zero tensor of shape
            ``(1, encoder_feature_dim)``.
        device: device for the dummy inputs. Defaults to the device of the
            model's first parameter, or ``'cuda'`` for a parameter-less model.

    Returns:
        ``(count_ops, count_params, ops_layers)``. ``ops_layers`` lists the running
        op total observed just before each measured leaf call, in call order.
    """
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    ops_layers = []

    if device is None:
        # Put the dummy inputs wherever the model lives instead of assuming CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'

    missing = object()  # marks "module had no instance-level forward before patching"
    patched = []        # (module, previous instance-level forward or `missing`)
    patched_ids = set()

    def wrap(module, original_forward):
        def counting_forward(x, *args, **kwargs):
            # Record the running total *before* this layer, then add its cost.
            ops_layers.append(count_ops)
            _measure_layer(module, x, flag_only)
            return original_forward(x, *args, **kwargs)

        return counting_forward

    def patch(module):
        if not _is_leaf(module):
            for child in module.children():
                patch(child)
            return
        if id(module) in patched_ids:
            # A leaf reachable through several parents is wrapped only once;
            # wrapping it twice would make the wrapper call itself forever.
            return
        patched_ids.add(id(module))
        patched.append((module, module.__dict__.get('forward', missing)))
        module.old_forward = module.forward
        module.forward = wrap(module, module.old_forward)

    def restore():
        for module, previous in patched:
            if previous is missing:
                del module.forward  # fall back to the class-level forward
            else:
                module.forward = previous
            module.old_forward = None

    # Remember each module's own train/eval flag so measuring leaves no trace.
    training_flags = [(m, m.training) for m in model.modules()]
    model.eval()
    patch(model)
    try:
        with torch.no_grad():  # counting only needs shapes, not an autograd graph
            if is_critic:
                obs = torch.zeros((1, encoder_feature_dim), device=device)
                action = torch.zeros((1, action_shape[0]), device=device)
                model.forward(obs, action)
            elif is_actor:
                model.forward(torch.zeros((1, encoder_feature_dim), device=device))
            else:
                model.forward(torch.zeros(1, C, H, W, device=device))
    finally:
        restore()
        for module, was_training in training_flags:
            module.training = was_training

    return count_ops, count_params, ops_layers


def _layer_name(layer):
    """Return the text of ``str(layer)`` before the first ``(``, stripped.

    Standard ``nn.Module`` reprs look like ``Conv2d(3, 8, ...)``, so this yields
    the layer type name. If the string contains no ``(``, the whole stripped
    string is returned (the previous ``find``-based slice dropped its last
    character in that case).
    """
    return str(layer).partition('(')[0].strip()


def _layer_param(model):
    """Return the total number of scalar elements across ``model.parameters()``."""
    return sum(param.numel() for param in model.parameters())


def _pair(value):
    """Return ``value`` as a 2-tuple, repeating it when it is a single scalar."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return value, value


def _window_out_size(size, kernel, stride, padding, dilation=1, ceil_mode=False):
    """Return the output length of one spatial dim of a sliding conv/pool window.

    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1, using
    ceil instead of floor when ``ceil_mode`` is set. In ceil mode, a last window
    that would start inside the right padding is dropped.
    """
    span = size + 2 * padding - dilation * (kernel - 1) - 1
    if ceil_mode:
        out = -(-span // stride) + 1
        if (out - 1) * stride >= size + padding:
            out -= 1
        return out
    return span // stride + 1


def _measure_layer(layer, x, flag_only):
    """Add the op and parameter cost of one leaf ``layer`` to the module totals.

    ``x`` is the layer's first positional input, seen just before the layer runs.
    Layer types are matched by ``_layer_name``:

    - Conv2d: C_in * C_out * kH * kW * outH * outW / groups. The batch dimension
      is ignored here, so call this with batch size 1.
    - ReLU, ChannelController2d: one op per input element.
    - AvgPool2d, MaxPool2d: N * C * outH * outW * kH * kW.
    - AdaptiveAvgPool2d: one op per input element.
    - Linear: rows * (weight elements + bias elements), where rows is the product
      of every input dim except the last.
    - BatchNorm1d/2d, Dropout/Dropout2d, DropChannel, CompressedSCU2d: zero ops,
      but their parameters are still counted.

    Any other type is ignored: neither ops nor params are added. When
    ``flag_only`` is true, only layers with a truthy ``measure`` attribute
    contribute.
    """
    global count_ops, count_params
    delta_ops = 0
    type_name = _layer_name(layer)
    multi_add = 1  # ops charged per multiply-accumulate

    # ops_conv
    if type_name == 'Conv2d':
        in_h, in_w = x.size()[2], x.size()[3]
        k_h, k_w = _pair(layer.kernel_size)
        if layer.padding == 'same':
            # padding='same' keeps the spatial size (PyTorch requires stride 1 for it).
            out_h, out_w = in_h, in_w
        else:
            p_h, p_w = (0, 0) if layer.padding == 'valid' else _pair(layer.padding)
            s_h, s_w = _pair(layer.stride)
            d_h, d_w = _pair(layer.dilation)
            out_h = _window_out_size(in_h, k_h, s_h, p_h, d_h)
            out_w = _window_out_size(in_w, k_w, s_w, p_w, d_w)
        delta_ops = x.size()[1] * layer.out_channels * k_h * k_w * \
            out_h * out_w / layer.groups * multi_add

    # ops_nonlinearity
    elif type_name in ('ReLU', 'ChannelController2d'):
        delta_ops = x.numel()

    # ops_pooling
    elif type_name in ('AvgPool2d', 'MaxPool2d'):
        k_h, k_w = _pair(layer.kernel_size)
        s_h, s_w = _pair(layer.stride)
        p_h, p_w = _pair(layer.padding)
        d_h, d_w = _pair(getattr(layer, 'dilation', 1))  # AvgPool2d has no dilation
        ceil_mode = getattr(layer, 'ceil_mode', False)
        out_h = _window_out_size(x.size()[2], k_h, s_h, p_h, d_h, ceil_mode)
        out_w = _window_out_size(x.size()[3], k_w, s_w, p_w, d_w, ceil_mode)
        delta_ops = x.size()[0] * x.size()[1] * out_h * out_w * k_h * k_w

    elif type_name == 'AdaptiveAvgPool2d':
        delta_ops = x.numel()

    # ops_linear
    elif type_name == 'Linear':
        # One row per position in every input dim except the feature dim.
        rows = reduce(operator.mul, x.size()[:-1], 1)
        weight_ops = layer.weight.numel() * multi_add
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        delta_ops = rows * (weight_ops + bias_ops)

    # ops_nothing
    elif type_name in ('BatchNorm1d', 'BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout',
                       'CompressedSCU2d'):
        # CompressedSCU2d has zero FLOPs when it's leaf
        pass

    # unknown layer type: deliberately skipped rather than raising
    else:
        return

    if flag_only and not getattr(layer, 'measure', False):
        return
    count_ops += delta_ops
    count_params += _layer_param(layer)