using LinearAlgebra
using CUDA
using Transformers
using Random
import TextEncodeBase as TEB
using NeuralAttentionlib.Masks

# Collect both modalities of a sample that carries text and a graph.
# Returns the named tuple `(; text, graph)` built from `extract_text(sample)`
# and `extract_graph(sample)`.
function extract_text_and_graph(sample::SampleType{>:Union{HasText, HasGraph}})
    text = extract_text(sample)
    graph = extract_graph(sample)
    return (; text, graph)
end
# Interface stub: a sample type that carries text is expected to add its own, more
# specific `extract_text` method. Reaching this fallback means none was defined, so it
# is reported as a missing method. `MethodError` stores the *tuple* of call arguments
# (error display splats it), hence `(sample,)` rather than the bare `sample`.
extract_text(sample::SampleType{>:HasText}) = throw(MethodError(extract_text, (sample,)))
# Interface stub: a sample type that carries a graph is expected to add its own, more
# specific `extract_graph` method. Reaching this fallback means none was defined, so it
# is reported as a missing method. `MethodError` stores the *tuple* of call arguments
# (error display splats it), hence `(sample,)` rather than the bare `sample`.
extract_graph(sample::SampleType{>:HasGraph}) = throw(MethodError(extract_graph, (sample,)))

"""
    triple2graph(triples)

Convert an iterable of `(head, edge, tail)` label triples into an ID-based graph.

Every label is appended to one shared `labels` list and its ID is its index in that list.
Node labels are unique: a head/tail label seen before reuses its ID. Edge labels are *not*
deduplicated: every triple appends its own edge entry, so the same edge text may appear
several times with different IDs.

Returns `(; labels, nodes, edges)` where `nodes` holds the node IDs in order of first
appearance and `edges` holds one `(head_id, edge_id, tail_id)` tuple per triple.
"""
function triple2graph(triples)
    labels = String[]
    node_ids = Dict{String, Int}()
    nodes = Int[]
    edges = NTuple{3,Int}[]

    # Look up the ID of node `name`; on first sight append it to `labels`,
    # record it in `nodes`, and remember the new ID.
    function node_id!(name)
        return get!(node_ids, name) do
            push!(labels, name)
            new_id = length(labels)
            push!(nodes, new_id)
            new_id
        end
    end

    for triple in triples
        @assert length(triple) == 3
        head, edge, tail = triple
        head_id = node_id!(head)
        tail_id = node_id!(tail)
        # Each triple contributes a fresh edge entry, even if the label text repeats.
        push!(labels, edge)
        push!(edges, (head_id, length(labels), tail_id))
    end
    return (; labels, nodes, edges)
end

"""
    random_graph_linearization(graph; full = false, rng = Random.default_rng())

Produce a random linear ordering of all component IDs (nodes and edges) of `graph`.

`graph` must provide `labels`, `nodes` and `edges`, where every edge is a
`(head, edge, tail)` triple of IDs. The nodes are shuffled first; then each edge ID is
inserted at a random position:

- `full = false` (default): one of the edge's endpoints (head or tail, picked at random)
  acts as an anchor and the edge is inserted somewhere *after* it, so at least one endpoint
  always precedes the edge in the result.
- `full = true`: the edge may be inserted anywhere; the constraint above is ignored.

`rng` is the random number generator used for every random draw; pass a seeded generator
to get a reproducible linearization.

Returns `(; labels, linear_graph, nodes, edges)`; `labels`, `nodes` and `edges` are passed
through unchanged and `linear_graph` is the new ordering.
"""
function random_graph_linearization(graph; full = false, rng = Random.default_rng())
    (; labels, nodes, edges) = graph
    linear_graph = shuffle(rng, nodes)   # `shuffle` returns a copy; `nodes` is untouched
    for (head, edge, tail) in edges
        if full
            possible_location = 1:(length(linear_graph) + 1)
        else
            anchor = rand(rng, (head, tail))
            location = findfirst(==(anchor), linear_graph)
            # any slot strictly after the anchor, up to and including the very end
            possible_location = (location + 1):(length(linear_graph) + 1)
        end
        insert!(linear_graph, rand(rng, possible_location), edge)
    end
    return (; labels, linear_graph, nodes, edges)
end

# Sample entry point: pull the graph out of `sample` and hand it to the graph-level
# `process_graph` method. The domain token defaults to `DomainToken(sample)`; all other
# keywords are forwarded untouched.
function process_graph(textenc, sample::SampleType{>:HasGraph};
                       domain_token = DomainToken(sample), kws...)
    graph = extract_graph(sample)
    return process_graph(textenc, graph; domain_token, kws...)
end
# Linearize `graph` (anything accepted by `random_graph_linearization`) and turn it into
# flat, per-token annotation vectors.
#
# Keywords:
# - `EOT_token`, `NODE_token`, `EDGE_token`: marker strings wrapped around every segment
#   (`NODE_token` or `EDGE_token` in front, `EOT_token` at the end).
# - `domain_token`: first output token; asserted to be set and non-empty.
# - `require_end`: when true, one extra `END_token` entry is appended. An `Int` is treated
#   as a threshold: the end entry is added iff `length(labels) >= require_end`.
# - `END_token`: label of that end entry (defaults to `domain_token`).
# - `rand_drop_elm`: when > 0, with probability 0.5 a fraction `rand_drop_elm` of the
#   linearized segments is left out of the output.
# - `shuffle_id`: when true, every ID-valued output is remapped through a random
#   permutation of `1:num_of_id`.
# - `full_shuffle`: forwarded as `full` to `random_graph_linearization`; a `Number` is
#   treated as the probability of enabling it.
#
# Returns `(; label, type, id, prev, segment, head, tail, pred_mask, num_of_id)`.
function process_graph(textenc, graph; EOT_token = "[EOT]", NODE_token = "[NODE]", EDGE_token = "[EDGE]",
                       domain_token = nothing, require_end = false, END_token = domain_token,
                       rand_drop_elm = false,
                       shuffle_id = false, full_shuffle = false)
    @assert !isnothing(domain_token) && !isempty(domain_token) "domain token must be set and not empty."
    # A numeric `full_shuffle` is a probability: draw once to decide for this call.
    if full_shuffle isa Number
        full_shuffle = rand() < full_shuffle
    end
    (; labels, linear_graph, nodes, edges) = random_graph_linearization(graph; full = full_shuffle)
    # An `Int` `require_end` is a size threshold (a `Bool` is not an `Int` and passes through).
    if require_end isa Int
        require_end = length(labels) >= require_end
    end
    text_segments = labels[linear_graph]                               # order the labels by the order of the
    tokenized_text_segments =                                          # linear graph and tokenize each
        TEB.tokenize(textenc, TEB.Batch{TEB.Sentence}(text_segments))  # segment independently.
    num_of_id =
        sum(length, tokenized_text_segments) + # number of tokens of all segments, each token gets a unique ID
        2 * length(tokenized_text_segments) +  # 2 extra marker tokens (begin / end-of-text) per segment
        1 +                                    # 1 domain token
        require_end                            # 1 end token if needed

    offset = 1                                                         # Each token is assigned a unique ID.
    id_based_label = zeros(Int, length(labels))                        # The ID of the first token of a segment
    for (id, segment) in zip(linear_graph, tokenized_text_segments)    # (its leading marker) is the segment ID.
        num_tokens = length(segment) + 2                               # Heads/tails point to segment IDs, so
        segment_id = offset + 1                                        # here we precompute the unique ID the
        id_based_label[id] = segment_id                                # first token will be assigned and
        offset += num_tokens                                           # construct a mapping from the old
    end                                                                # component ID to the new segment ID.

    # Optional random dropping: with probability 0.5 (and only if `rand_drop_elm > 0`) pick
    # `round(len * rand_drop_elm)` positions of the linearization to leave out. IDs were
    # reserved above for every segment, so dropped segments keep their (unused) ID range.
    if rand_drop_elm > 0 && rand() < 0.5
        p = Float64(rand_drop_elm)
        len = length(linear_graph)
        drop_n = round(Int, len * p)
        drop_idx = sort!(shuffle(1:len)[1:drop_n])
        dropped_linear_graph = copy(linear_graph)
        splice!(dropped_linear_graph, drop_idx)
    else
        dropped_linear_graph = linear_graph
        drop_idx = nothing
    end

    # Every output vector starts with the entry of the domain token (ID 1). The
    # `Vector(undef, offset) |> empty!` idiom creates an empty vector from an allocation of
    # `offset` elements before the first `push!`.
    domain_token_id = 1
    unique_ids = Vector{Int32}(undef, offset) |> empty! |> Base.Fix2(push!, domain_token_id) # == Int32[ domain_token_id ]
    segment_ids = Vector{Int32}(undef, offset) |> empty! |> Base.Fix2(push!, domain_token_id) # == Int32[ domain_token_id ]
    # the domain token is treated as a node; type : 1=>node, 2=>edge
    types = Vector{Int}(undef, offset) |> empty! |> Base.Fix2(push!, 1) # == Int[ 1 ]
    prevs = Vector{Int32}(undef, offset) |> empty! |> Base.Fix2(push!, 1) # == Int32[ 1 ]
    heads = Vector{Int32}(undef, offset) |> empty! |> Base.Fix2(push!, 1) # == Int32[ 1 ]
    tails = Vector{Int32}(undef, offset) |> empty! |> Base.Fix2(push!, 1) # == Int32[ 1 ]
    for (id, segment) in zip(linear_graph, tokenized_text_segments)
        isnode = id in nodes
        # The markers are added in place to every segment, including dropped ones
        # (those are removed from `tokenized_text_segments` after the loop).
        push!(segment, TEB.Token(EOT_token))
        pushfirst!(segment, TEB.Token(isnode ? NODE_token : EDGE_token))
        segment_id = id_based_label[id]
        num_tokens = length(segment)
        offset = segment_id - 1          # from here on `offset` is the ID just before this segment
        if !(id in dropped_linear_graph)
            continue                     # dropped segment: contributes no annotation entries
        end
        append!(unique_ids, Base.OneTo(num_tokens) .+ offset)
        append!(segment_ids, Iterators.repeated(segment_id, num_tokens))
        append!(types, Iterators.repeated(isnode ? 1 : 2, num_tokens))
        push!(prevs, 1); append!(prevs, Base.OneTo(num_tokens - 1) .+ offset) # the domain token precedes the first token
        if isnode                                                             # of every segment
            append!(heads, Iterators.repeated(segment_id, num_tokens))
            append!(tails, Iterators.repeated(segment_id, num_tokens))        # a node points to itself
        else
            # Find the triple this edge ID belongs to. Head/tail use the precomputed
            # segment IDs, so they may refer to a node segment that was dropped.
            edge = edges[findfirst(triple->triple[2] == id, edges)]
            head, _, tail = edge
            head_id = id_based_label[head]
            tail_id = id_based_label[tail]
            append!(heads, Iterators.repeated(head_id, num_tokens))           # all tokens of the same edge segment
            append!(tails, Iterators.repeated(tail_id, num_tokens))           # point to the same head and tail
        end
    end
    # Remove the dropped segments' tokens, then flatten everything behind the domain token.
    !isnothing(drop_idx) && splice!(tokenized_text_segments, drop_idx)
    tokenized_text_labels = TEB.nestedcall(TextEncoders.string_getvalue, tokenized_text_segments)
    output_tokens = foldl(append!, tokenized_text_labels; init = String[ domain_token ])
    if require_end
        # The end entry is a one-token node segment of its own: fresh ID, previous token is
        # the domain token, and segment/head/tail all point to itself.
        push!(output_tokens, END_token)
        push!(types, 1)
        id = maximum(unique_ids) + 1
        push!(unique_ids, id)
        push!(prevs, 1)
        push!(segment_ids, id)
        push!(heads, id)
        push!(tails, id)
    end
    # Mask over all tokens except the last one: true where the token is the domain token
    # or an end-of-text marker.
    pred_mask = GenericSequenceMask(map(@view(output_tokens[begin:end-1])) do token
        return token == domain_token || token == EOT_token
    end)::GenericSequenceMask{2, Matrix{Bool}}
    if shuffle_id
        # Relabel every ID through one random permutation of 1:num_of_id.
        remap = shuffle(Int32(1):Int32(num_of_id))
        unique_ids = remap[unique_ids]
        prevs = remap[prevs]
        segment_ids = remap[segment_ids]
        heads = remap[heads]
        tails = remap[tails]
    end
    return (; label = output_tokens, type = types, id = unique_ids,
            prev = prevs, segment = segment_ids, head = heads, tail = tails,
            pred_mask, num_of_id = num_of_id)
end

# Sample entry point: pull the text out of `sample` and hand it to the text-level
# `process_text` method, forwarding every keyword untouched.
function process_text(textenc, sample::SampleType{>:HasText}; kws...)
    text = extract_text(sample)
    return process_text(textenc, text; kws...)
end
"""
    process_text(textenc, text; kws...)

Tokenize `text` and annotate it as a single node segment,
`[domain_token, NODE_token, tokens..., EOT_token]`.

# Keywords
- `EOT_token`, `NODE_token`, `domain_token`: marker strings placed around the tokens.
- `rand_drop_elm`: when `> 0`, tokens may be replaced by `"<unk>"`: with probability 0.3
  a fraction `rand_drop_elm` of them, with probability 0.5 half that many, otherwise none.
- `shuffle_id`: when true, every ID-valued output is remapped through a random permutation
  of `1:num_of_id`.
- `require_end`: when true, one extra `END_token` entry with its own ID is appended.
- `END_token`: label of that end entry (defaults to `domain_token`).

Returns `(; label, type, id, prev, segment, head, tail, pred_mask, num_of_id)`.
"""
function process_text(textenc, text; EOT_token = "[EOT]", NODE_token = "[NODE]", domain_token = "[TEXT]",
                      rand_drop_elm = false,
                      shuffle_id = false, require_end = false, END_token = domain_token)
    tokens = TEB.tokenize(textenc, TEB.Sentence(text))
    # tokens + domain token + node marker + end-of-text marker
    num_of_id = length(tokens) + 3

    types = ones(Int, num_of_id)                       # everything is of node type (1)
    unique_ids = collect(Int32(1):Int32(num_of_id))    # token i has ID i
    segment_ids = fill(Int32(2), num_of_id)            # the whole text is one segment (ID 2) ...
    segment_ids[1] = Int32(1)                          # ... preceded by the domain token (ID 1)
    prevs = Int32[max(i - 1, 1) for i in 1:num_of_id]  # 1, 1, 2, ..., num_of_id - 1
    heads = copy(segment_ids)
    tails = copy(segment_ids)

    token_labels = TEB.nestedcall(TextEncoders.string_getvalue, tokens)

    if rand_drop_elm > 0
        n_tokens = length(token_labels)
        full_count = round(Int, n_tokens * Float64(rand_drop_elm))
        roll = rand()
        # 30%: mask `full_count` tokens, 50%: mask half as many, 20%: mask nothing.
        mask_count = roll < 0.3 ? full_count : (roll < 0.8 ? div(full_count, 2) : 0)
        # The first case always draws a shuffle (even for a zero count); the second one
        # only does so when there is something to mask.
        if roll < 0.3 || mask_count > 0
            masked = sort!(shuffle(1:n_tokens)[1:mask_count])
            token_labels[masked] .= "<unk>"
        end
    end

    output_tokens = vcat(domain_token, NODE_token, token_labels, EOT_token)
    if require_end
        # The end entry gets a fresh ID, follows the domain token, and its
        # segment/head/tail all point to itself.
        num_of_id += 1
        push!(output_tokens, END_token)
        push!(types, 1)
        push!(prevs, 1)
        for ids in (unique_ids, segment_ids, heads, tails)
            push!(ids, num_of_id)
        end
    end

    # Mask over all tokens except the last one: true at the domain token and at
    # end-of-text markers.
    is_boundary(token) = token == domain_token || token == EOT_token
    pred_mask = GenericSequenceMask(
        map(is_boundary, @view(output_tokens[begin:end-1])))::GenericSequenceMask{2, Matrix{Bool}}

    if shuffle_id
        # Relabel every ID through one random permutation of 1:num_of_id.
        remap = shuffle(Int32(1):Int32(num_of_id))
        unique_ids = remap[unique_ids]
        prevs = remap[prevs]
        segment_ids = remap[segment_ids]
        heads = remap[heads]
        tails = remap[tails]
    end
    return (; label = output_tokens, type = types, id = unique_ids,
            prev = prevs, segment = segment_ids, head = heads, tail = tails,
            pred_mask, num_of_id = num_of_id)
end

# Batch entry point for graph-carrying samples: run `process_graph` on every sample
# (per-sample ID shuffling disabled) and collate the results with `_batch_process_data`.
# `END_token` is only forwarded when it was given, so that `process_graph` otherwise
# falls back to its own default. Remaining keywords (`kws`) go to `_batch_process_data`.
function batch_process_graph(textenc, samples::AbstractVector{S};
                             EOT_token = "[EOT]", NODE_token = "[NODE]", EDGE_token = "[EDGE]",
                             require_end = false, END_token = nothing, full_shuffle = false,
                             rand_drop_elm = false,
                             kws...) where {S <: SampleType{>:HasGraph}}
    end_option = isnothing(END_token) ? NamedTuple() : (; END_token)
    pgraphs = map(samples) do sample
        process_graph(textenc, sample; EOT_token, NODE_token, EDGE_token, require_end,
                      rand_drop_elm, shuffle_id = false, full_shuffle, end_option...)
    end
    return _batch_process_data(textenc, pgraphs; kws...)
end
# Batch entry point for raw graphs: run `process_graph` on every graph with one shared
# set of options (per-graph ID shuffling disabled) and collate the results with
# `_batch_process_data`. `domain_token` must be set and non-empty; remaining keywords
# (`kws`) go to `_batch_process_data`.
function batch_process_graph(textenc, graphs; EOT_token = "[EOT]", NODE_token = "[NODE]", EDGE_token = "[EDGE]",
                             domain_token = nothing, require_end = false, END_token = domain_token, full_shuffle = false,
                             rand_drop_elm = false,
                             kws...)
    @assert !isnothing(domain_token) && !isempty(domain_token) "domain token must be set and not empty."
    options = (; EOT_token, NODE_token, EDGE_token, domain_token, require_end, END_token,
               rand_drop_elm, shuffle_id = false, full_shuffle)
    pgraphs = map(g -> process_graph(textenc, g; options...), graphs)
    return _batch_process_data(textenc, pgraphs; kws...)
end

# Batch entry point for text-carrying samples: extract the text of every sample and
# delegate to the text-level `batch_process_text` method with the same keywords.
function batch_process_text(textenc, samples::AbstractVector{S}; kws...) where {S<:SampleType{>:HasText}}
    texts = extract_text.(samples)
    return batch_process_text(textenc, texts; kws...)
end
# Batch entry point for raw texts: run `process_text` on every text with one shared set
# of options (per-text ID shuffling disabled) and collate the results with
# `_batch_process_data`. Remaining keywords (`kws`) go to `_batch_process_data`.
function batch_process_text(textenc, texts; EOT_token = "[EOT]", NODE_token = "[NODE]", domain_token = "[TEXT]",
                            rand_drop_elm = false,
                            require_end = false, END_token = domain_token, kws...)
    options = (; EOT_token, NODE_token, domain_token, shuffle_id = false, require_end, END_token,
               rand_drop_elm)
    ptexts = map(text -> process_text(textenc, text; options...), texts)
    return _batch_process_data(textenc, ptexts; kws...)
end

"""
    _batch_process_data(textenc, pdata; shuffle_id = false, remap_max = nothing, force_trunc = nothing)

Collate the per-example outputs of `process_graph` / `process_text` (`pdata`) into padded
batch arrays.

# Keywords
- `shuffle_id`: when true, all IDs of the batch are relabeled through one shared random
  permutation. A `Number` is treated as the probability of doing so.
- `remap_max`: size of that permutation. It is never smaller than the largest `num_of_id`
  in the batch and defaults to exactly that value.
- `force_trunc`: maximum sequence length passed to `TEB.trunc_and_pad`; `nothing` means no
  truncation.

Returns `(; label, type, id, prev, segment, head, tail, pred_mask, attention_mask,
head_label, tail_label)`. `head_label` / `tail_label` are the head/tail IDs *before*
remapping. Padding positions use `textenc.padsym` for labels, `false` for `pred_mask`
and `1` for every other field.
"""
function _batch_process_data(textenc, pdata;
                             shuffle_id = false, remap_max = nothing, force_trunc = nothing)
    if shuffle_id isa Number
        shuffle_id = rand() < shuffle_id
    end
    length_cap = something(force_trunc, typemax(Int))
    lens = map(x -> Int32(min(length(x.label), length_cap))::Int32, pdata)
    max_length = maximum(lens)
    max_num_of_id = maximum(x -> x.num_of_id, pdata)
    if shuffle_id
        # `shuffle_id` alone enables shuffling; `remap_max` only widens the permutation.
        # (Previously the shuffle was silently skipped whenever `remap_max` was not given.)
        remap_size = max(something(remap_max, max_num_of_id), max_num_of_id)
        remap = shuffle(Int32(1):Int32(remap_size))
    else
        remap = Int32(1):Int32(max_num_of_id)   # identity mapping
    end

    # Gather one field from every example, truncate/pad it and stack it into a batch array.
    batched(field, pad) = TEB.nested2batch(TEB.trunc_and_pad(map(field, pdata), force_trunc, pad))

    labels     = batched(x -> x.label, textenc.padsym)
    type_ids   = batched(x -> x.type, 1)
    id         = remap[batched(x -> x.id, 1)]
    prev       = remap[batched(x -> x.prev, 1)]
    segment    = remap[batched(x -> x.segment, 1)]
    head_label = batched(x -> x.head, 1)
    tail_label = batched(x -> x.tail, 1)
    head       = remap[head_label]
    tail       = remap[tail_label]
    # Each per-example mask covers all tokens but the last, hence `max_length - 1`.
    pred_mask = GenericSequenceMask(TEB.nested2batch(TEB.trunc_and_pad(
        map(x -> reshape(x.pred_mask.mask, :), pdata), max_length - 1, false)))::GenericSequenceMask{3, Array{Bool,3}}
    mask = LengthMask(lens)
    label = lookup(OneHot, textenc.vocab, labels)
    type = OneHotArray{2}(type_ids)
    return (; label, type, id, prev, segment, head, tail, pred_mask, attention_mask = mask, head_label, tail_label)
end

# `maybe_free!(x)` eagerly releases the GPU memory held by `x` and always returns `nothing`.
# Only `CuArray`s are actually freed; containers and wrappers forward to their contents and
# host-side values are ignored.

# Containers: visit every element.
maybe_free!(nt::Union{NamedTuple, Tuple}) = (foreach(maybe_free!, nt); nothing)

# The one case that does real work.
function maybe_free!(x::CuArray)
    CUDA.unsafe_free!(x)
    return nothing
end

# Host-side data: nothing to release.
maybe_free!(::Array) = nothing
maybe_free!(::Union{Nothing, Number, Ref}) = nothing

# Wrappers: forward to the wrapped storage.
maybe_free!(x::LinearAlgebra.Adjoint) = maybe_free!(parent(x))
maybe_free!(m::Union{LengthMask, RevLengthMask}) = maybe_free!(m.len)
maybe_free!(m::GenericSequenceMask) = maybe_free!(m.mask)
maybe_free!(x::OneHotArray) = maybe_free!(x.onehots)

# Allocate one set of uninitialized host buffers, each passed through `CUDA.Mem.pin`, for a
# batch of up to `max_length` x `batch_size`. The field names match what the `NamedTuple`
# method of `toprealloc` reads.
function prealloc(max_length, batch_size)
    pinned(T, dims...) = Array{T}(undef, dims...) |> CUDA.Mem.pin
    int_matrix() = pinned(Int32, max_length, batch_size)
    return (;
        onehots = int_matrix(),
        type = int_matrix(),
        id = int_matrix(),
        prev = int_matrix(),
        segment = int_matrix(),
        head = int_matrix(),
        tail = int_matrix(),
        pred_mask = pinned(Bool, 1, max_length, batch_size),
        len_mask = pinned(Int32, batch_size),
        head_label = int_matrix(),
        tail_label = int_matrix(),
    )
end

# Copy the one-hot indices of `data` into the front of `prealloc` and return a `OneHotArray`
# that aliases that memory with the shape of `data.onehots`.
# The result is created with `unsafe_wrap`, so `prealloc` must outlive it.
function toprealloc(prealloc::Array{Int32}, data::OneHotArray{K}) where K
    indices = data.onehots
    copyto!(prealloc, reinterpret(Int32, indices))
    ptr = convert(Ptr{OneHot{K}}, pointer(prealloc))
    aliased = unsafe_wrap(Array{OneHot{K}, ndims(prealloc)}, ptr, size(indices))
    return OneHotArray(aliased)
end
# Copy the boolean mask of `data` into the front of `prealloc` and return a
# `GenericSequenceMask` that aliases that memory with the shape of `data.mask`.
# The result is created with `unsafe_wrap`, so `prealloc` must outlive it.
function toprealloc(prealloc::Array{Bool}, data::GenericSequenceMask)
    mask = data.mask
    copyto!(prealloc, mask)
    rank = ndims(prealloc)
    aliased = unsafe_wrap(Array{Bool, rank}, pointer(prealloc), size(mask))
    return GenericSequenceMask{rank}(aliased)
end
# Copy the lengths of `data` into the front of `prealloc` and return a `LengthMask` that
# aliases that memory with the shape of `data.len`.
# The result is created with `unsafe_wrap`, so `prealloc` must outlive it.
function toprealloc(prealloc::Array{Int32}, data::LengthMask)
    lengths = data.len
    copyto!(prealloc, lengths)
    rank = ndims(prealloc)
    aliased = unsafe_wrap(Array{Int32, rank}, pointer(prealloc), size(lengths))
    return LengthMask(aliased)
end
# Generic case: copy `data` into the front of `prealloc` (converting to its element type)
# and return an array with the shape of `data` that aliases that memory.
# The result is created with `unsafe_wrap`, so `prealloc` must outlive it.
function toprealloc(prealloc::Array{T}, data) where T
    copyto!(prealloc, data)
    alias_type = Array{T, ndims(prealloc)}
    return unsafe_wrap(alias_type, pointer(prealloc), size(data))
end
# Move a whole collated batch (`data`, as returned by `_batch_process_data`) into the buffer
# set `prealloc` (as returned by `prealloc`), field by field. The result has the same field
# names as `data`; note that `data.label` lands in `prealloc.onehots` and
# `data.attention_mask` in `prealloc.len_mask`.
function toprealloc(prealloc::NamedTuple, data)
    return (;
        label = toprealloc(prealloc.onehots, data.label),
        type = toprealloc(prealloc.type, data.type),
        id = toprealloc(prealloc.id, data.id),
        prev = toprealloc(prealloc.prev, data.prev),
        segment = toprealloc(prealloc.segment, data.segment),
        head = toprealloc(prealloc.head, data.head),
        tail = toprealloc(prealloc.tail, data.tail),
        pred_mask = toprealloc(prealloc.pred_mask, data.pred_mask),
        attention_mask = toprealloc(prealloc.len_mask, data.attention_mask),
        head_label = toprealloc(prealloc.head_label, data.head_label),
        tail_label = toprealloc(prealloc.tail_label, data.tail_label),
    )
end
