# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# adapted from OFA: https://github.com/mit-han-lab/once-for-all

import torch
import torch.nn as nn

from .activations import *

def make_divisible(v, divisor=8, min_value=1):
    """Round ``v`` to the nearest multiple of ``divisor``, with two adjustments.

    Forked from the TF-slim MobileNet helper:
    https://github.com/tensorflow/models/blob/0344c5503ee55e24f0de7f37336a6e08f10976fd/research/slim/nets/mobilenet/mobilenet.py#L62-L69

    Args:
        v: Value to round.
        divisor: Step the value is rounded to.
        min_value: Lower bound applied after rounding; ``None`` means
            "use ``divisor`` as the lower bound".

    Returns:
        The rounded value, raised to ``min_value`` if it fell below it, and
        bumped up by one extra ``divisor`` if it ended up under 90% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Rounding down must not shrink the value by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result


def sub_filter_start_end(kernel_size, sub_kernel_size):
    """Return the ``(start, end)`` slice bounds of a centered sub-kernel.

    The window is centered on index ``kernel_size // 2`` and extends
    ``sub_kernel_size // 2`` positions to each side; ``end`` is exclusive.

    Raises:
        AssertionError: if the resulting window length differs from
            ``sub_kernel_size`` (which happens for even ``sub_kernel_size``).
    """
    half_width = sub_kernel_size // 2
    middle = kernel_size // 2
    lo = middle - half_width
    hi = middle + half_width + 1
    assert hi - lo == sub_kernel_size
    return lo, hi


def get_net_device(net):
    """Return the device of the first parameter yielded by ``net.parameters()``."""
    first_param = next(net.parameters())
    return first_param.device


def int2list(val, repeat_time=1):
    """Normalize ``val`` to a list.

    Args:
        val: A list (returned unchanged, same object), a tuple (converted to
            a new list), or any other value (repeated).
        repeat_time: Number of copies of ``val`` in the result when ``val``
            is neither a list nor a tuple.

    Returns:
        A list as described above.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val] * repeat_time



def get_same_padding(kernel_size):
    """Return the padding ``kernel_size // 2`` for an odd kernel size.

    Args:
        kernel_size: An odd ``int``, or a 2-tuple of odd ``int`` values
            (handled element-wise).

    Returns:
        An ``int`` for an ``int`` input, or a 2-tuple of ``int`` for a
        tuple input.

    Raises:
        AssertionError: if a tuple does not have exactly two elements, if
            the value is neither ``int`` nor ``tuple``, or if a size is even.
    """
    if isinstance(kernel_size, tuple):
        # Wrap the tuple in a 1-tuple: `'%s' % some_tuple` would otherwise
        # unpack it as the format arguments and raise TypeError instead of
        # the intended AssertionError.
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % (kernel_size,)
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be odd number'
    return kernel_size // 2


def copy_bn(target_bn, src_bn):
    """Copy the leading ``weight`` and ``bias`` entries of ``src_bn`` into ``target_bn``.

    The number of entries copied is the length of ``target_bn.weight``, so
    ``src_bn`` must have at least that many entries. The copy is done in
    place on the ``.data`` tensors.

    Only ``weight`` and ``bias`` are copied; ``running_mean`` and
    ``running_var`` are left untouched by this function.

    Args:
        target_bn: Module whose ``weight`` / ``bias`` are overwritten.
        src_bn: Module whose ``weight`` / ``bias`` are read.
    """
    feature_dim = target_bn.weight.data.shape[0]
    target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
    # bias is sliced along its single dimension, exactly like weight;
    # a two-index slice (`[:, :n]`) fails on a 1-D tensor.
    target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])


def build_activation(act_func, inplace=True):
    """Build an activation module from its name.

    Args:
        act_func: One of ``'relu'``, ``'relu6'``, ``'tanh'``, ``'sigmoid'``,
            ``'h_swish'``, ``'h_sigmoid'``, ``'swish'``, or ``None``.
        inplace: Forwarded to the activations that accept an ``inplace``
            argument (ReLU6, Hswish, Hsigmoid).

    Returns:
        The activation module, or ``None`` when ``act_func`` is ``None``.

    Raises:
        ValueError: for any other ``act_func`` value.
    """
    if act_func is None:
        return None
    # NOTE(review): 'relu' currently yields ReLU6, not ReLU - confirm this
    # is intended before relying on unbounded ReLU behaviour.
    if act_func in ('relu', 'relu6'):
        return nn.ReLU6(inplace=inplace)
    if act_func == 'tanh':
        return nn.Tanh()
    if act_func == 'sigmoid':
        return nn.Sigmoid()
    if act_func == 'h_swish':
        return Hswish(inplace=inplace)
    if act_func == 'h_sigmoid':
        return Hsigmoid(inplace=inplace)
    if act_func == 'swish':
        return MemoryEfficientSwish()
    raise ValueError('do not support: %s' % act_func)


def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch during training.

    One keep/drop decision is drawn per entry along dimension 0 and
    broadcast over all remaining dimensions. Kept samples are divided by
    ``1 - p``.

        Args:
            inputs (tensor): Input of this structure; dimension 0 is treated
                as the batch dimension.
            p (float: 0.0~1.0): Probability of drop connection.
            training (bool): The running mode.
        Returns:
            output: ``inputs`` itself when ``training`` is false, otherwise
                a new tensor of the same shape after drop connection.
        Raises:
            AssertionError: if ``p`` is outside ``[0, 1]``.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'
    if not training:
        return inputs
    keep_prob = 1.0 - p
    if keep_prob == 0:
        # p == 1 drops every sample. Dividing by keep_prob would compute
        # 0 / 0 = NaN, so return zeros directly; multiplying (rather than
        # allocating a fresh tensor) keeps the result tied to `inputs`.
        return inputs * 0

    # One random value per sample, with singleton trailing dims so the mask
    # broadcasts over inputs of any rank.
    mask_shape = [inputs.shape[0]] + [1] * (inputs.dim() - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    output = inputs / keep_prob * binary_tensor
    return output


