import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.modules import module
from math import sqrt


def initialize_conv(weight, weight_mask, sqrt_sparsity=False):
    """Rescale ``weight`` in place by the kept fraction of ``weight_mask``.

    The kept fraction ("density") is ``weight_mask.sum() / weight_mask.numel()``.
    ``weight`` is divided in place by ``sqrt(density)``; nothing is returned.

    Args:
        weight: Tensor to rescale in place.
        weight_mask: Tensor whose sum over all elements, divided by its
            element count, gives the kept fraction (presumably a 0/1 mask
            with the same shape as ``weight`` -- confirm against callers).
        sqrt_sparsity: When true, the kept fraction is square-rooted before
            use, so ``weight`` ends up divided by its fourth root.

    The function is a no-op when the mask is empty or sums to zero.
    """
    mask_size = weight_mask.numel()
    if mask_size == 0:
        # An empty mask has no defined kept fraction; bail out instead of
        # dividing by zero below.
        return

    with torch.no_grad():
        density = weight_mask.sum().item() / mask_size
    if density == 0:
        # No parameters left: scaling would divide by zero.
        return

    if sqrt_sparsity:
        density = sqrt(density)

    # NOTE(review): this appears to mirror a fan-in based init where the
    # fan-in is multiplied by the kept fraction, i.e. the bound grows by
    # 1 / sqrt(density) -- confirm against the initialiser used upstream.
    with torch.no_grad():
        weight /= sqrt(density)


def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups``.

    The channel axis (dim 1) is viewed as ``(groups, channels // groups)``,
    those two axes are swapped, and the result is flattened back to the
    input's 4-D shape.

    Args:
        x: 4-D tensor; ``x.size()`` is unpacked as
            ``(batch, channels, height, width)``.
        groups: Number of groups; must divide the channel count evenly
            (checked with an ``assert``).

    Returns:
        A tensor with the same shape as ``x`` and its channels reordered.
    """
    n, c, h, w = x.size()
    assert (c % groups == 0)
    per_group = c // groups

    grouped = x.view(n, groups, per_group, h, w)
    # The transpose yields a non-contiguous tensor; make it contiguous so
    # the final view back to 4-D is valid.
    swapped = grouped.transpose(1, 2).contiguous()
    return swapped.view(n, c, h, w)


def round_groups(channels, groups):
    """Return the largest value in ``[1, groups]`` that divides ``channels``.

    Candidates are tried from ``groups`` downwards. ``None`` is returned
    when there is no candidate at all (``groups`` below 1).
    """
    candidates = range(groups, 0, -1)
    return next((g for g in candidates if channels % g == 0), None)


def round_channels(channels, divisor=8):
    """Round ``channels`` to a multiple of ``divisor``.

    The value is rounded to the nearest multiple (half a divisor is added
    before flooring), with ``divisor`` itself as the minimum. If the result
    is below 90% of ``channels``, one more ``divisor`` is added.

    Args:
        channels: Channel count to round.
        divisor: Step the result must be a multiple of (default 8).

    Returns:
        The rounded channel count.
    """
    half_step = divisor / 2.0
    nearest_multiple = int(channels + half_step) // divisor * divisor
    result = max(nearest_multiple, divisor)

    # Do not let rounding shrink the count by more than 10%.
    lower_bound = 0.9 * channels
    if float(result) < lower_bound:
        result += divisor
    return result
