import torch
from models.batch_renorm import BatchRenorm2D


def replace_bn_with_brn(
        m, name="", momentum=0.1, r_d_max_inc_step=0.0001, r_max=1.0,
        d_max=0.0, max_r_max=3.0, max_d_max=5.0):
    """Swap every nested ``BatchNorm2d`` under ``m`` for a ``BatchRenorm2D``.

    The module tree is edited in place. Each replacement layer is built from
    the old layer's ``num_features``, ``weight`` (passed as ``gamma``),
    ``bias`` (passed as ``beta``), ``running_mean``, ``running_var`` and
    ``eps``. The old layer's own momentum is not carried over; the
    ``momentum`` argument is used instead.

    Only descendants of ``m`` are inspected, so ``m`` itself is never
    replaced, even when it is a ``BatchNorm2d``.

    Args:
        m: Root module whose descendants are searched.
        name: Unused; the recursion fills it with the child's attribute name.
        momentum: Forwarded to every new ``BatchRenorm2D``.
        r_d_max_inc_step: Forwarded to every new ``BatchRenorm2D``.
        r_max: Forwarded to every new ``BatchRenorm2D``.
        d_max: Forwarded to every new ``BatchRenorm2D``.
        max_r_max: Forwarded to every new ``BatchRenorm2D``.
        max_d_max: Forwarded to every new ``BatchRenorm2D``.

    Returns:
        None.
    """
    # Settings shared by every replacement layer and every recursive call.
    brn_settings = {
        "momentum": momentum,
        "r_d_max_inc_step": r_d_max_inc_step,
        "r_max": r_max,
        "d_max": d_max,
        "max_r_max": max_r_max,
        "max_d_max": max_d_max,
    }

    for attr_name, submodule in m.named_children():
        if not isinstance(submodule, torch.nn.BatchNorm2d):
            # Not a batch-norm layer: keep searching below it.
            replace_bn_with_brn(submodule, attr_name, **brn_settings)
            continue

        renorm_layer = BatchRenorm2D(
            submodule.num_features,
            gamma=submodule.weight,
            beta=submodule.bias,
            running_mean=submodule.running_mean,
            running_var=submodule.running_var,
            eps=submodule.eps,
            **brn_settings
        )
        # Re-binding the existing attribute name swaps the layer in the parent.
        setattr(m, attr_name, renorm_layer)


def change_brn_pars(
        m, name="", momentum=0.1, r_d_max_inc_step=0.0001, r_max=1.0,
        d_max=0.0):
    """Overwrite the tunable attributes of every ``BatchRenorm2D`` in ``m``.

    ``momentum``, ``r_max`` and ``d_max`` are stored as new tensors with
    ``requires_grad=False``; ``r_d_max_inc_step`` is stored as the plain value
    passed in. ``m`` itself is updated when it is a ``BatchRenorm2D``;
    otherwise its children are searched recursively.

    NOTE(review): ``BatchRenorm2D`` is defined outside this file. Confirm that
    it reads an attribute named ``r_d_max_inc_step`` after construction, and
    that replacing the tensors with fresh ones created on the default device
    is fine when the model lives on another device.

    Args:
        m: Module to update (or to search for ``BatchRenorm2D`` layers).
        name: Unused; the recursion fills it with the child's attribute name.
        momentum: New value for ``momentum``.
        r_d_max_inc_step: New value for ``r_d_max_inc_step``.
        r_max: New value for ``r_max``.
        d_max: New value for ``d_max``.

    Returns:
        None.
    """
    if isinstance(m, BatchRenorm2D):
        m.momentum = torch.tensor(momentum, requires_grad=False)
        m.r_max = torch.tensor(r_max, requires_grad=False)
        m.d_max = torch.tensor(d_max, requires_grad=False)
        m.r_d_max_inc_step = r_d_max_inc_step
        # As before, the search does not continue below a matched layer.
        return

    for child_name, child in m.named_children():
        change_brn_pars(child, child_name, momentum, r_d_max_inc_step, r_max, d_max)


def change_bn_momentum(m, momentum=0.001):
    """Set ``momentum`` on every ``BatchNorm2d`` found in ``m``.

    ``m`` itself is updated when it is a ``BatchNorm2d``; otherwise its
    children are searched recursively. Other normalisation layers (for
    example ``BatchNorm1d``) are left untouched.

    Args:
        m: Module to update (or to search for ``BatchNorm2d`` layers).
        momentum: Value assigned to each layer's ``momentum`` attribute.

    Returns:
        None.
    """
    if isinstance(m, torch.nn.BatchNorm2d):
        m.momentum = momentum
        # As before, the search does not continue below a matched layer.
        return

    for child in m.children():
        change_bn_momentum(child, momentum)


def freeze_conv_dw(m):
    """Disable gradients for every ``Conv2d`` in ``m`` with a 3x3 kernel.

    A layer is selected purely by ``kernel_size == (3, 3)``; ``groups`` is not
    checked, so despite the ``dw`` in the name (presumably "depthwise") every
    3x3 ``Conv2d`` is frozen. ``m`` itself is handled when it is a
    ``Conv2d``; otherwise its children are searched recursively.

    Args:
        m: Module to freeze (or to search for 3x3 ``Conv2d`` layers).

    Returns:
        None.
    """
    if isinstance(m, torch.nn.Conv2d):
        if m.kernel_size == (3, 3):
            for param in m.parameters():
                param.requires_grad = False
        # As before, the search does not continue below a Conv2d of any size.
        return

    for child in m.children():
        freeze_conv_dw(child)


def set_brn_to_train(m, name=""):
    """Call ``train()`` on every ``BatchRenorm2D`` found in ``m``.

    ``m`` itself is switched when it is a ``BatchRenorm2D``; otherwise its
    children are searched recursively. Modules of any other type are only
    traversed, never switched.

    Args:
        m: Module to switch (or to search for ``BatchRenorm2D`` layers).
        name: Unused; the recursion fills it with the child's attribute name.

    Returns:
        None.
    """
    if isinstance(m, BatchRenorm2D):
        m.train()
        # As before, the search does not continue below a matched layer.
        return

    for child_name, child in m.named_children():
        set_brn_to_train(child, child_name)


def set_brn_to_eval(m, name=""):
    """Call ``eval()`` on every ``BatchRenorm2D`` found in ``m``.

    ``m`` itself is switched when it is a ``BatchRenorm2D``; otherwise its
    children are searched recursively. Modules of any other type are only
    traversed, never switched.

    Args:
        m: Module to switch (or to search for ``BatchRenorm2D`` layers).
        name: Unused; the recursion fills it with the child's attribute name.

    Returns:
        None.
    """
    if isinstance(m, BatchRenorm2D):
        m.eval()
        # As before, the search does not continue below a matched layer.
        return

    for child_name, child in m.named_children():
        set_brn_to_eval(child, child_name)


def set_bn_to(m, name="", phase="train"):
    """Put every ``BatchNorm2d`` found in ``m`` into train or eval mode.

    ``m`` itself is switched when it is a ``BatchNorm2d``; otherwise its
    children are searched recursively. Modules of any other type are only
    traversed, so their own mode is left as it was.

    Args:
        m: Module to switch (or to search for ``BatchNorm2d`` layers).
        name: Unused; the recursion fills it with the child's attribute name.
        phase: ``"train"`` selects ``train()``; any other value selects
            ``eval()``.

    Returns:
        None.
    """
    if isinstance(m, torch.nn.BatchNorm2d):
        if phase == "train":
            m.train()
        else:
            m.eval()
        # As before, the search does not continue below a matched layer.
        return

    for child_name, child in m.named_children():
        set_bn_to(child, child_name, phase)