import torch
import torch.fx as fx
from torch.fx.node import _get_qualified_name


def eliminate_detach(fx_graph: fx.GraphModule) -> fx.GraphModule:
    """Remove ``aten.detach.default`` calls from an FX graph module, in place.

    Every use of a detach node's result is rewired to the detach's input.
    Dead-code elimination then drops the now-unused detach nodes. Removing a
    detach means gradients are no longer blocked at that point. This pass is
    presumably meant for graphs where that does not matter; confirm with
    callers.

    Args:
        fx_graph: The graph module to rewrite. It is mutated in place.

    Returns:
        The same ``fx_graph`` object, recompiled so that its generated
        ``forward`` matches the rewritten graph.
    """
    detach = torch.ops.aten.detach.default
    for node in fx_graph.graph.nodes:
        # Compare the OpOverload object directly. This avoids building a
        # qualified-name string for every node, and it does not depend on
        # every call target having a resolvable qualified name.
        if node.op == 'call_function' and node.target is detach:
            # The detach node stays in the graph with no users until the
            # dead-code elimination below removes it.
            node.replace_all_uses_with(node.args[0])

    # Note: this also removes any other unused node that FX does not
    # consider impure, not just the detaches orphaned above.
    fx_graph.graph.eliminate_dead_code()
    # A GraphModule caches its generated code. Without recompile(), forward()
    # would keep executing the old graph, detach calls included.
    fx_graph.recompile()

    return fx_graph