"""
    processed_graph_to_graph(graph; debug = true)

Alias of `processed_graph_to_triple` that forwards `graph` and `debug` unchanged and
returns its result (a vector of string triples).
"""
function processed_graph_to_graph(graph; debug = true)
    return processed_graph_to_triple(graph; debug = debug)
end
"""
    processed_graph_to_triple(graph; NODE_token = "[NODE]", EDGE_token = "[EDGE]",
                              EOT_token = "[EOT]", debug = false)

Convert a flat token sequence into a vector of `(head_text, edge_text, tail_text)` triples.

`graph` must be convertible to a `NamedTuple` with fields
`(label, type, id, prev, segment, head, tail)`, each holding one entry per token.
Tokens that share a `segment` value form one component. The token whose `id` equals its
`segment` is the component's header:
- its `type` selects node (`1`) or edge (anything else);
- for edges, its `head`/`tail` name the segments the edge connects.

Component text is the concatenation of its labels, with `NODE_token`, `EDGE_token` and
`EOT_token` removed and `'▁'` turned into spaces, then stripped.

Edges are dropped if their head or tail segment does not exist or their text is empty.
A segment without a header token yields an empty sentence instead of throwing.
`debug` is currently unused.
"""
function processed_graph_to_triple(graph; NODE_token = "[NODE]", EDGE_token = "[EDGE]",
                                   EOT_token = "[EOT]", debug = false)
    graph = NamedTuple{(:label, :type, :id, :prev, :segment, :head, :tail)}(graph)
    special_tokens = (NODE_token, EDGE_token, EOT_token)

    # Compact the (possibly sparse) original segment ids to 1:n, preserving their order.
    sorted_oids = sort!(collect(Set{Int}(graph.segment)))
    oid2nid = Dict{Int, Int}(oid => nid for (nid, oid) in enumerate(sorted_oids))

    # One pass over the tokens: accumulate every segment's words and record its header(s).
    # Node/edge entries alias the segment's word vector, so words appended after the
    # header are still part of that node/edge text.
    words = [String[] for _ in 1:length(oid2nid)]
    node_words = Vector{String}[]
    edge_words = Tuple{Int, Vector{String}, Int}[]
    for (label, type, id, _, segment, head, tail) in zip(graph...)
        nid = oid2nid[segment]
        label in special_tokens || push!(words[nid], label)
        id == segment || continue
        if type == 1
            push!(node_words, words[nid])
        elseif haskey(oid2nid, head) && haskey(oid2nid, tail)
            push!(edge_words, (oid2nid[head], words[nid], oid2nid[tail]))
        end
    end

    clean(pieces) = strip(replace(join(pieces), '▁' => ' '))
    sentences = map(clean, words)
    node_texts = map(clean, node_words)
    edge_texts = map(e -> (e[1], clean(e[2]), e[3]), filter(e -> !isempty(e[2]), edge_words))
    return to_triples((; nodes = node_texts, edges = edge_texts, sentences))
end

"""
    to_triples(graph)

Turn `graph.edges` (tuples of `(head_index, edge_text, tail_index)`) into unique
`(head_sentence, edge_text, tail_sentence)` string triples. Head and tail sentences are
looked up in `graph.sentences`. Duplicate triples are collapsed; the result is a `Vector`
in set-iteration order.
"""
function to_triples(graph)
    (; edges, sentences) = graph
    unique_triples = Set{NTuple{3, String}}()
    foreach(edges) do (h, relation, t)
        push!(unique_triples, (sentences[h], relation, sentences[t]))
    end
    return collect(unique_triples)
end

"""
    _inference(textenc, model, data, domain_token; max_length = 100, max_component = 8,
               ends_at = domain_token, debug = true)

Decode a token sequence conditioned on the encoder output for `data`.

Decoding starts from a single seeded `domain_token` (id/segment/head/tail all `1`) and runs
up to `max_component` calls of `inference_component!`. The logit of `ends_at` is used
by `inference_component!` as the stop signal.

When a component call reports that it stopped, a `[NODE]` component is generated for every
id left in `unknown_ids` (ids referenced as edge endpoints but not yet generated), until
the id budget `max_length` runs out. Those ids are not generated if the loop instead
exhausts `max_component`.

Returns the `tokens` named tuple of parallel vectors
`(label, type, id, prev, segment, head, tail)`.
"""
function _inference(textenc, model, data, domain_token; max_length = 100, max_component = 8,
                    ends_at = domain_token, debug = true)
    enc = encoder_forward(model, textenc, data)
    node_id, edge_id, ends_id = lookup(textenc.vocab, ("[NODE]", "[EDGE]", ends_at))
    @assert node_id != edge_id
    @assert lookup(textenc.vocab, node_id) == "[NODE]"
    @assert lookup(textenc.vocab, edge_id) == "[EDGE]"
    tokens = (
        label = [ domain_token ], type = Int32[ 1 ],
        id = Int32[ 1 ], prev = Int32[ 1 ], segment = Int32[ 1 ],
        head = Int32[ 1 ], tail = Int32[ 1 ]
    )
    unknown_ids = Set{Int32}()
    seen_ids = Set{Int32}()
    all_id_embedding = model.embed.id_embed(max_length) |> todevice
    new_id = 2
    # A `for i = ...` loop always binds a fresh loop-local `i`, so a counter cannot be
    # inspected after the loop. Track explicitly whether generation stopped on its own.
    stopped = false
    for _ in 1:max_component
        (cont, tokens, decs, new_id) =
            inference_component!(model, textenc, tokens, enc, new_id, all_id_embedding, seen_ids, unknown_ids;
                                 EOT_token = "[EOT]", node_id, edge_id, ends_id, debug)
        cont && continue
        stopped = true
        debug && println("generate remaining nodes")
        for id in unknown_ids
            if new_id > max_length
                debug && println("max_length exceed")
                break
            end
            push_token!(tokens, (;
                label   = "[NODE]",
                type    = 1,
                id      = id,
                prev    = 1,
                segment = id,
                head    = id,
                tail    = id,
            ))
            tokens, decs, new_id = inference_text!(model, textenc, tokens, enc, new_id;
                                                   EOT_token = "[EOT]", max_length)
        end
        break
    end
    debug && !stopped && println("generate end by max_component")
    debug && @show unknown_ids
    debug && @show seen_ids
    return tokens
end

"""
    encoder_forward(model, textenc, tokens)

Run the encoder over `tokens`.

`label` is replaced by its one-hot vocabulary encoding and `type` by a 2-class one-hot
array. The input is moved with `todevice`, embedded with `model.embed`, and passed
through `model.seq2seq.encoder`. The device copy of the input is released with
`maybe_free!` before returning the encoder `hidden_state`.
"""
function encoder_forward(model, textenc, tokens)
    onehot_fields = (
        label = lookup(OneHot, textenc.vocab, tokens.label),
        type  = OneHotArray{2}(tokens.type),
    )
    device_input = todevice(merge(tokens, onehot_fields))
    hidden = model.seq2seq.encoder(model.embed(device_input)).hidden_state
    maybe_free!(device_input)
    return hidden
end

"""
    decoder_forward(model, textenc, enc, tokens)

Run the decoder over `tokens`, cross-attending to the encoder output `enc`.

`label` and `type` are one-hot encoded (vocabulary / 2 classes) and the input is moved
with `todevice`. It is then embedded and fed to `model.seq2seq.decoder` together with
`memory = enc`. The device copy of the input is released with `maybe_free!`; the raw
decoder output is returned.
"""
function decoder_forward(model, textenc, enc, tokens)
    onehot_fields = (
        label = lookup(OneHot, textenc.vocab, tokens.label),
        type  = OneHotArray{2}(tokens.type),
    )
    device_input = todevice(merge(tokens, onehot_fields))
    embedded = model.embed(device_input)
    output = model.seq2seq.decoder(merge(embedded, (memory = enc,)))
    maybe_free!(device_input)
    return output
end

"""
    inference_label_step(model, tokens, decs, new_id)

Apply `model.stage1` to the decoder hidden states plus id embeddings.

The id embeddings are those of the existing token ids shifted left by one position, with
`new_id` appended as the id of the token being predicted. The embedding is reshaped to 3
dimensions and freed with `maybe_free!` after use. Returns the stage-1 output.
"""
function inference_label_step(model, tokens, decs, new_id)
    shifted_ids = vcat(@view(tokens.id[2:end]), new_id)
    id_emb = reshape(model.embed.id_embed(todevice(shifted_ids)), Val(3))
    stage1 = model.stage1((hidden_state = decs.hidden_state, id_embedding = id_emb))
    maybe_free!(id_emb)
    return stage1
end

"""
    inference_struct_step(model, tokens, stage1_state, new_segment_id, all_id_embedding)

Apply `model.stage2` to the stage-1 hidden state to score structure (head/tail) choices.

The segment embeddings come from the existing segment ids shifted left by one position,
with `new_segment_id` appended; they are reshaped to 3 dimensions. `all_id_embedding`
supplies the candidate id embeddings. The segment embedding is freed with `maybe_free!`
after use. Returns the stage-2 output.
"""
function inference_struct_step(model, tokens, stage1_state, new_segment_id, all_id_embedding)
    shifted_segments = vcat(@view(tokens.segment[2:end]), new_segment_id)
    seg_emb = reshape(model.embed.id_embed(todevice(shifted_segments)), Val(3))
    stage2 = model.stage2((hidden_state = stage1_state,
                           segment_embedding = seg_emb,
                           id_embedding = all_id_embedding))
    maybe_free!(seg_emb)
    return stage2
end

"""
    push_token!(tokens, new_token)

Append one token to `tokens`, a collection of parallel vectors.

Each field of `new_token` is pushed onto the `tokens` vector of the same name, in the order
`label, type, id, prev, segment, head, tail`. Returns `tokens`.
"""
function push_token!(tokens, new_token)
    for field in (:label, :type, :id, :prev, :segment, :head, :tail)
        push!(getproperty(tokens, field), getproperty(new_token, field))
    end
    return tokens
end

"""
    inference_component!(model, textenc, tokens, enc, new_id, all_id_embedding, seen_ids, unknown_ids;
                         EOT_token = "[EOT]", node_id, edge_id, ends_id, debug = true)

Try to generate one more component and append it to `tokens` in place. A component is a
`[NODE]` or `[EDGE]` header token followed by its text, produced by `inference_text!`.

Returns `(cont, tokens, decs, new_id)`:
- `cont` is `false` when nothing was generated. This happens when `new_id` already exceeds
  the id budget, or when the `ends_id` logit is at least as large as both the `node_id` and
  `edge_id` logits. Otherwise `cont` is `true`.
- `decs` is the most recent decoder output. After text generation it comes from
  `inference_text!` and may be `nothing` if that produced no token.
- `new_id` is the id to assign to the next generated token.

Side effects on the id sets:
- `seen_ids` always receives the new header id.
- For an edge, each endpoint id not already in `seen_ids` is added to both `seen_ids` and
  `unknown_ids`, so the caller can generate those nodes later.
"""
function inference_component!(model, textenc, tokens, enc, new_id, all_id_embedding, seen_ids, unknown_ids;
                              EOT_token = "[EOT]", node_id, edge_id, ends_id, debug = true)
    # The id budget is the number of precomputed id embeddings.
    # Assumes they are laid out as (dim, max_length) — TODO confirm.
    max_length = size(all_id_embedding, 2)
    decs = decoder_forward(model, textenc, enc, tokens)
    # Out of ids: stop without generating (the fresh decoder output is still returned).
    new_id > max_length && return (false, tokens, decs, new_id)
    stage1 = inference_label_step(model, tokens, decs, new_id)
    # Logits for the position after the last token.
    # The scalar indexing below presumably assumes batch size 1 — NOTE(review): confirm.
    label_logit = collect(@view(stage1.label_logits[:, end, :]))
    node_prob = label_logit[node_id]
    edge_prob = label_logit[edge_id]
    ends_prob = label_logit[ends_id]
    # Ties are resolved in favour of stopping.
    if ends_prob >= node_prob && ends_prob >= edge_prob
        debug && println("generated end token: stop")
        return (false, tokens, decs, new_id)
    end
    push!(seen_ids, new_id)
    # Ties between node and edge are resolved in favour of node.
    isnode = node_prob >= edge_prob
    label = isnode ? "[NODE]" : "[EDGE]"
    if !isnode
        # Candidate endpoints are ids 1:min(new_id + 2, max_length), so an edge may point at
        # up to two ids that have not been generated yet.
        max_new_id = min(new_id + 2, max_length)
        stage2 = inference_struct_step(model, tokens, stage1.hidden_state,
                                       new_id, @view(all_id_embedding[:, 1:max_new_id]))
        head_logits = reshape(collect(@view stage2.head_logits[:, end, :]), :)
        tail_logits = reshape(collect(@view stage2.tail_logits[:, end, :]), :)
        # An edge may not use its own header id as an endpoint.
        head_logits[new_id] = typemin(eltype(head_logits))
        tail_logits[new_id] = typemin(eltype(tail_logits))
        head_id = Flux.onecold(head_logits)[]
        tail_id = Flux.onecold(tail_logits)[]
        if head_id == tail_id
            # label_logit is edge, but we failed to inference the correct head and tail
            # so we set it as node, which is less harm to the overall prediction.
            label = "[NODE]"
            isnode = true
            head_id = tail_id = new_id
        else
            # Endpoints not seen before are queued so the caller can generate them later.
            if !(head_id in seen_ids)
                push!(seen_ids, head_id)
                push!(unknown_ids, head_id)
            end
            if !(tail_id in seen_ids)
                push!(seen_ids, tail_id)
                push!(unknown_ids, tail_id)
            end
        end
    else
        head_id = tail_id = new_id
    end
    # Header token: its own segment, prev fixed to 1, head/tail = endpoints (or itself for a node).
    push_token!(tokens, (;
        label,
        type    = isnode ? 1 : 2,
        id      = new_id,
        prev    = 1,
        segment = new_id,
        head    = head_id,
        tail    = tail_id,
    ))
    # Skip past any forward-referenced endpoint ids so they stay reserved for those nodes.
    new_id = max(new_id, head_id, tail_id) + 1
    tokens, decs, new_id = inference_text!(model, textenc, tokens, enc, new_id; EOT_token, max_length)
    return (true, tokens, decs, new_id)
end

"""
    inference_text!(model, textenc, tokens, enc, new_id; EOT_token = "[EOT]", max_length)

Greedily append text tokens to the current component of `tokens`.

Each step runs the decoder and stage 1, picks the arg-max vocabulary label, and pushes it.
The new token copies `type`, `segment`, `head` and `tail` from the previous token and
uses that token's `id` as its `prev`.

Stops after emitting `EOT_token`, or once the next id would exceed `max_length`.

Returns `(tokens, decs, next_id)`. `decs` is the last decoder output, or `nothing` if no
step ran.
"""
function inference_text!(model, textenc, tokens, enc, new_id; EOT_token = "[EOT]", max_length)
    decs = nothing
    next_id = new_id
    while next_id <= max_length
        decs = decoder_forward(model, textenc, enc, tokens)
        logits = collect(@view(inference_label_step(model, tokens, decs, next_id).label_logits[:, end, :]))
        word = lookup(textenc.vocab, Flux.onecold(logits)[])
        push_token!(tokens, (;
            label   = word,
            type    = last(tokens.type),
            id      = next_id,
            prev    = last(tokens.id),
            segment = last(tokens.segment),
            head    = last(tokens.head),
            tail    = last(tokens.tail),
        ))
        next_id += 1
        if word == EOT_token
            break
        end
    end
    return (tokens, decs, next_id)
end

"""
    inference_g2d(textenc, model, graph; max_length = 100, domain_token, ends_at = domain_token,
                  debug = true, max_component = 8)

Decode from a graph input.

`graph` is preprocessed with `TextGraphBART.process_graph` (using `domain_token`), then
passed to `_inference` with the remaining options forwarded unchanged.
"""
function inference_g2d(textenc, model, graph; max_length = 100, domain_token, ends_at = domain_token, debug = true, max_component = 8)
    processed = TextGraphBART.process_graph(textenc, graph; domain_token = domain_token)
    return _inference(textenc, model, processed, domain_token;
                      max_length = max_length, ends_at = ends_at,
                      debug = debug, max_component = max_component)
end

"""
    inference_t2d(textenc, model, text; max_length = 100, domain_token, ends_at = domain_token,
                  debug = true, max_component = 8)

Decode from a text input.

`text` is preprocessed with `TextGraphBART.process_text`, then passed to `_inference`
(seeded with `domain_token`) with the remaining options forwarded unchanged.
"""
function inference_t2d(textenc, model, text; max_length = 100, domain_token, ends_at = domain_token, debug = true, max_component = 8)
    processed = TextGraphBART.process_text(textenc, text)
    return _inference(textenc, model, processed, domain_token;
                      max_length = max_length, ends_at = ends_at,
                      debug = debug, max_component = max_component)
end

"""
    inference_g2t(textenc, model, graph; max_length = 1000, domain_token = "[TEXT]", debug = true)

Generate a text string from `graph` by decoding a single component (`max_component = 1`)
with `inference_g2d`.

The returned text is built from the decoded labels as follows:
- drop the leading seeded domain token;
- drop an immediately following `"[NODE]"` header, if present;
- drop a trailing `"[EOT]"`, if present;
- concatenate the remaining labels, turn `'▁'` into spaces, and strip the result.

If decoding stopped before generating any token, an empty string is returned.
"""
function inference_g2t(textenc, model, graph; max_length = 1000, domain_token = "[TEXT]", debug = true)
    tokens = inference_g2d(textenc, model, graph; max_length, domain_token, max_component = 1, debug)
    labels = tokens.label
    # labels[1] is the seeded domain token. The sequence may contain nothing else if the
    # model stopped immediately, so bounds-check before peeking at the header.
    first_idx = 2
    if first_idx <= lastindex(labels) && labels[first_idx] == "[NODE]"
        first_idx += 1
    end
    last_idx = lastindex(labels)
    if labels[last_idx] == "[EOT]"
        last_idx -= 1
    end
    pred = @view labels[first_idx:last_idx]
    return strip(replace(join(pred), '▁' => ' '))
end
