import torch
from torch.fx.node import _get_qualified_name


def _insert_call_after(graph, anchor, target, args, kwargs, meta):
    """Create a ``call_function`` node placed directly after ``anchor`` and return it."""
    # ``inserting_after`` puts each new node immediately after its anchor, so a
    # fresh context per node is what keeps chained nodes in dependency order.
    with graph.inserting_after(anchor):
        new_node = graph.call_function(target, args=args, kwargs=kwargs)
    # Downstream passes may read node.meta (e.g. "val"); the mm/mul/add outputs
    # have the same shape as the addmm output they replace.
    new_node.meta = dict(meta)
    return new_node


def fix_addmm(fx_module: torch.fx.GraphModule):
    """Split every ``aten.addmm.default`` call into ``mm`` (+ ``mul``) + ``add`` nodes.

    ``aten.addmm(input, mat1, mat2, *, beta=1, alpha=1)`` computes
    ``beta * input + alpha * (mat1 @ mat2)``. Each matching node is retargeted
    in place to ``aten.mm.default(mat1, mat2)`` (keeping its name and position
    in the graph), then:

    * an ``aten.mul.Tensor(mm, alpha)`` node is appended when ``alpha != 1``;
    * an ``aten.add.Tensor(prev, input, alpha=beta)`` node is appended unless
      ``beta == 0`` (``torch.addmm`` ignores ``input`` in that case, so NaN/inf
      in it must not reach the result).

    Every former user of the addmm node is rewired to the last node of that
    chain. Newly inserted nodes receive a shallow copy of the addmm node's meta.

    Args:
        fx_module: Graph module to rewrite; it is mutated in place and recompiled.

    Returns:
        The same ``fx_module`` object.
    """
    graph = fx_module.graph
    addmm_op = torch.ops.aten.addmm.default

    # Snapshot the node list: new nodes are inserted while we walk the graph.
    for node in list(graph.nodes):
        # Exact identity on the OpOverload; unlike building a qualified-name
        # string, this cannot raise for unusual call_function targets.
        if node.op != 'call_function' or node.target is not addmm_op:
            continue

        bias, mat1, mat2 = node.args
        beta = node.kwargs.get('beta', 1)
        alpha = node.kwargs.get('alpha', 1)

        # Record consumers before inserting nodes that will themselves use `node`.
        consumers = list(node.users)

        # Reuse the addmm node as the plain matmul.
        node.target = torch.ops.aten.mm.default
        node.args = (mat1, mat2)
        # aten.mm takes no beta/alpha; leaving them would break the generated code.
        node.kwargs = {}

        result = node
        if alpha != 1:
            result = _insert_call_after(
                graph, result, torch.ops.aten.mul.Tensor, (result, alpha), {}, node.meta
            )
        if beta != 0:
            # add.Tensor(a, b, alpha=k) computes a + k * b, which applies beta to the bias.
            add_kwargs = {} if beta == 1 else {'alpha': beta}
            result = _insert_call_after(
                graph, result, torch.ops.aten.add.Tensor, (result, bias), add_kwargs, node.meta
            )

        if result is not node:
            for user in consumers:
                user.replace_input_with(node, result)

    fx_module.recompile()

    return fx_module
