import torch
from torch.fx.node import _get_qualified_name


def mf_batch_norm_(input, weight, bias, running_mean, running_var, training, momentum, eps):
    """Validate per-channel tensor shapes, then call ``aten.native_batch_norm``.

    Used as the replacement target for ``torch.ops.aten.native_batch_norm.default``
    nodes (see ``fix_batch_norm``), so it takes the same positional arguments.

    For 4-D inputs, every per-channel tensor that is provided must have a
    leading dimension equal to ``input.shape[1]``. Inputs of any other rank are
    passed through without this check.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` may be ``None``,
    as the ``native_batch_norm`` schema allows (e.g. non-affine batch norm or
    batch norm without running statistics). Absent tensors are skipped by the
    shape check instead of crashing on ``None.shape``.

    Returns:
        The result of ``torch.ops.aten.native_batch_norm.default`` unchanged.

    Raises:
        RuntimeError: if a provided per-channel tensor's first dimension does
            not match the channel dimension of a 4-D ``input``.
    """
    if len(input.shape) == 4:
        num_channels = input.shape[1]
        per_channel = (weight, bias, running_mean, running_var)
        # Only tensors that are actually present can be shape-checked.
        if any(t is not None and t.shape[0] != num_channels for t in per_channel):
            raise RuntimeError("shape not right for batch_norm")
    return torch.ops.aten.native_batch_norm.default(input, weight, bias, running_mean, running_var,
                                                    training, momentum, eps)


def fix_batch_norm(fx_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Retarget every ``aten.native_batch_norm.default`` call to ``mf_batch_norm_``.

    Args:
        fx_module: The graph module to rewrite. Its graph is mutated in place.

    Returns:
        The same ``fx_module`` object, recompiled so that its generated
        ``forward`` reflects the rewritten graph.
    """
    batch_norm_op = torch.ops.aten.native_batch_norm.default
    for node in fx_module.graph.nodes:
        # Match the target object itself rather than substring-searching its
        # qualified name. This is exact and does not depend on the string
        # format produced by the private ``_get_qualified_name`` helper.
        if node.op == 'call_function' and node.target is batch_norm_op:
            node.target = mf_batch_norm_

    fx_module.recompile()

    return fx_module
