import torch.nn as nn

def print_requires_grad_per_layer(network):
    """Print the ``requires_grad`` flag of every named parameter.

    One line is printed per entry of ``network.named_parameters()``, in the
    form ``<parameter name> -> <requires_grad>``.

    Args:
        network: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
    """
    for param_name, param in network.named_parameters():
        print(param_name, "->", param.requires_grad)


def _print_is_training(name, net):
    """Recursively print the ``training`` flag of each leaf module.

    A module without children is printed as ``<name> -> <training>``. For a
    module with children nothing is printed for the module itself; the
    function recurses into each child, passing the child's own (unqualified)
    name from ``named_children()``.

    Args:
        name: Label printed when ``net`` turns out to be a leaf.
        net: Module to inspect.
    """
    named_kids = list(net.named_children())

    # Leaf module: report it and stop descending.
    if not named_kids:
        print(name, "->", net.training)
        return

    for child_name, child in named_kids:
        _print_is_training(child_name, child)


def print_is_training_per_layer(network):
    """Print the ``training`` flag of the leaf modules of ``network``.

    Delegates to ``_print_is_training``, labelling the root module
    ``"network"``.

    Args:
        network: Root module to inspect.
    """
    root_label = "network"
    _print_is_training(root_label, network)


def _running_stat_differs(stat1, stat2):
    """Return True when two running-statistic buffers are not identical."""
    # With ``track_running_stats=False`` the buffers are ``None``. Two missing
    # buffers count as equal; a missing one against a present one does not.
    if stat1 is None or stat2 is None:
        return stat1 is not stat2
    # Differently shaped buffers can never be equal. Checking this first
    # avoids a broadcasting error (or a silent broadcast) in the comparison.
    if stat1.shape != stat2.shape:
        return True
    return bool(stat1.ne(stat2).any())


def check_equal_running_stats(net1, net2):
    """Compare BatchNorm running statistics of two networks, pairwise.

    Both module trees are walked in lockstep. Whenever both current modules
    are leaves and both are ``BatchNorm1d``/``BatchNorm2d``/``BatchNorm3d``
    layers, ``"***"`` is printed, followed by ``"Mean not equal!!"`` and/or
    ``"VAR not equal!!"`` if the respective running buffers differ.

    Layers that do not track running statistics (buffers are ``None``) are
    handled without raising: two untracked layers are considered equal.

    Args:
        net1: First module (tree) to compare.
        net2: Second module (tree) to compare.

    Returns:
        None. Results are reported on standard output only.
    """
    batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    # At least one side still has children: descend pairwise. ``zip`` stops
    # at the shorter side, so structurally different tails are ignored.
    if list(net1.children()) or list(net2.children()):
        for child1, child2 in zip(net1.children(), net2.children()):
            check_equal_running_stats(child1, child2)
        return

    both_batch_norm = isinstance(net1, batch_norm_types) and isinstance(
        net2, batch_norm_types
    )
    if not both_batch_norm:
        return

    print("***")
    if _running_stat_differs(net1.running_mean, net2.running_mean):
        print("Mean not equal!!")
    if _running_stat_differs(net1.running_var, net2.running_var):
        print("VAR not equal!!")


class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``.

        ``view`` is used, so no data is copied and the input must be
        view-compatible with the requested shape.
        """
        leading = x.shape[0]
        flat = x.view(leading, -1)
        return flat


class Unflatten(nn.Module):
    """Module that views its input as ``(x.size(0), channel, height, width)``."""

    def __init__(self, channel, height, width):
        """Store the target trailing dimensions.

        Args:
            channel: Size of dimension 1 of the output.
            height: Size of dimension 2 of the output.
            width: Size of dimension 3 of the output.
        """
        super(Unflatten, self).__init__()
        self.channel, self.height, self.width = channel, height, width

    def forward(self, x):
        """Return ``x`` viewed with the stored trailing dimensions.

        The first dimension of ``x`` is kept as is. ``view`` is used, so no
        data is copied and the input must be view-compatible with the shape.
        """
        target_shape = (x.shape[0], self.channel, self.height, self.width)
        return x.view(*target_shape)
