using Unicode
using InternedStrings
using EzXML
using TextEncodeBase

#### Dataset type definition

# Configuration handle for the WebNLG 2020 corpus.
# NOTE(review): the `HasTrainSet`/`HasDevSet`/`HasTestSet` tags are declared
# elsewhere; presumably they advertise which splits are available -- confirm.
struct WebNLG2020 <: DataSetConfig{Union{HasTrainSet, HasDevSet, HasTestSet}}
    # Root directory of the corpus; the `TrainSet`/`DevSet`/`TestSet` methods in
    # this file join "train", "dev" and a "test/..." XML file name onto it.
    path::String
end

# Thin wrapper around a loaded split.
# NOTE(review): `Indexable`/`ReIterable` are declared elsewhere; presumably
# they describe the access pattern this wrapper supports -- confirm.
struct WebNLG2020DataSet{T} <: DataSetType{Union{Indexable, ReIterable}}
    # Underlying collection of entries; `length` is forwarded to it and
    # integer indexing wraps its elements in `WebNLG2020Sample`.
    dataset::T
end

# Number of entries in the wrapped collection.
function Base.length(data::WebNLG2020DataSet)
    return length(data.dataset)
end

# Single-entry access: wrap the raw entry in a `WebNLG2020Sample`.
function Base.getindex(data::WebNLG2020DataSet, i)
    return WebNLG2020Sample(data.dataset[i])
end

# Multi-entry access: index one position at a time so that every element is
# wrapped by the single-entry method.
function Base.getindex(data::WebNLG2020DataSet, is::AbstractVector)
    return map(is) do i
        data[i]
    end
end

# Iteration protocol; the state is the next 1-based position.
function Base.iterate(data::WebNLG2020DataSet, i = 1)
    (0 < i <= length(data)) || return nothing
    return (data[i], i + 1)
end

# One dataset entry as handed out by `WebNLG2020DataSet` indexing.
# NOTE(review): `HasText`/`HasGraph` are declared elsewhere; presumably they
# signal that `extract_text` and `extract_graph` apply -- confirm.
struct WebNLG2020Sample{T} <: SampleType{Union{HasText, HasGraph}}
    # Raw entry; `extract_text` reads its `texts` field and `extract_graph`
    # reads its `triples` field.
    sample::T
end

"""
    list_webnlg_files(path)

Return a `Vector{String}` with the full path of every file found under `path`,
descending into sub-directories in `walkdir` order.
"""
function list_webnlg_files(path)
    found = String[]
    for (root, _, files) in walkdir(path)
        for file in files
            push!(found, joinpath(root, file))
        end
    end
    return found
end

"""
    camelsplit(text)

Break `text` in front of every uppercase character, lowercase the pieces and
join them with single spaces.

A leading uppercase character yields an empty first piece, so the result then
starts with a space.
"""
function camelsplit(text)
    pieces = String[]
    start = 1
    for (idx, ch) in pairs(text)
        isuppercase(ch) || continue
        # Everything from the previous cut up to (not including) this capital.
        push!(pieces, SubString(text, start, prevind(text, idx)))
        start = idx
    end
    push!(pieces, SubString(text, start, lastindex(text)))
    @assert join(pieces) == text
    return join(Iterators.map(lowercase, pieces), ' ')
end

"""
    load_webnlg_test_id_cat_map(file)

Read the XML `file` and return a `Dict{String, String}` that maps each entry's
`eid` attribute to its `category` attribute. Both strings are interned.

The root element must have exactly one child element, whose children are the
entries.
"""
function load_webnlg_test_id_cat_map(file)
    entries = elements(readxml(file).root)[]
    id2category = Dict{String, String}()
    foreach(eachelement(entries)) do node
        category = intern(node["category"])
        id2category[intern(node["eid"])] = category
    end
    return id2category
end

# Parse every node matching `xpath` below `entry` into an interned
# `(head, edge, tail)` triple: surrounding double quotes are stripped from head
# and tail, underscores become spaces, and the camel-case edge is split into
# lowercase words.
function _parse_webnlg_triples(entry, xpath)
    triples = NTuple{3, String}[]
    for node in findall(xpath, entry)
        head, edge, tail = split(node.content, " | ")
        head = replace(strip(head, '"'), '_' => ' ')
        tail = replace(strip(tail, '"'), '_' => ' ')
        edge = camelsplit(edge)
        push!(triples, (intern(head), intern(edge), intern(tail)))
    end
    return triples
end

"""
    load_generated_webnlg_file(file; isref = false, mapping = nothing)

Read the WebNLG XML `file`, normalise its triples and write the result back out
split by category.

The output goes into a directory named after `file` without its extension,
created next to `file`. It receives `all.xml` with every entry plus one
`<category>.xml` per category, all written through `write_webnlg_file`.

# Keywords
- `isref`: when `true`, triples are read from `modifiedtripleset/mtriple`;
  otherwise from `generatedtripleset/gtriple`, and duplicate triples within an
  entry are removed.
- `mapping`: `eid => category` lookup used for entries whose `category`
  attribute is blank. It is only consulted in that case.

# Throws
- `ArgumentError` if an entry has a blank category and no `mapping` was given.

Returns `nothing`.
"""
function load_generated_webnlg_file(file; isref = false, mapping = nothing)
    T = @NamedTuple{
        category::String, eid::String,
        triples::Vector{NTuple{3, String}}}
    dataset = T[]
    dataset_by_cat = Dict{String, Vector{T}}()
    # The two modes differ only in where the triples live and in deduplication.
    xpath = isref ? "modifiedtripleset/mtriple" : "generatedtripleset/gtriple"

    entries = elements(readxml(file).root)[]
    for entry in eachelement(entries)
        category = intern(entry["category"])
        eid = intern(entry["eid"])
        if isempty(strip(category))
            isnothing(mapping) && throw(ArgumentError(
                "entry $eid has no category and no `mapping` was provided"))
            category = mapping[eid]
        end
        triples = _parse_webnlg_triples(entry, xpath)
        isref || unique!(triples)
        record = (; category, eid, triples)
        push!(dataset, record)
        # Function form of `get!` only allocates the bucket for a new category.
        push!(get!(() -> T[], dataset_by_cat, category), record)
    end

    outdir = mkpath(joinpath(dirname(file), first(splitext(basename(file)))))
    write_webnlg_file(joinpath(outdir, "all.xml"), dataset; isref)
    for (category, cdataset) in dataset_by_cat
        write_webnlg_file(joinpath(outdir, "$category.xml"), cdataset; isref)
    end
    return nothing
end

"""
    nfkd_nonascii(text)

Apply NFKD normalisation to `text`, map the minus sign '−' to '-' and the right
single quotation mark '’' to an ASCII apostrophe, then drop every character
whose code point is 128 or above.
"""
function nfkd_nonascii(text)
    decomposed = Base.Unicode.normalize(text, :NFKD)
    minus_fixed = replace(decomposed, '−' => '-')
    quote_fixed = replace(minus_fixed, '’' => '\'')
    return filter(ch -> codepoint(ch) < 128, quote_fixed)
end
"""
    punctnum_split(text)

Separate punctuation from neighbouring word characters and mark the boundary
in front of punctuation and digits with '▁'. Three rewrites run in sequence:

1. insert a space between a word character and a following punctuation mark,
   except in front of a period that ends the text;
2. insert a space between a punctuation mark and a following word character;
3. replace a space that precedes punctuation, '|' or a digit with '▁'.
"""
function punctnum_split(text)
    spaced_before = replace(
        text, r"(?!.\.$)([^[:punct:]\s\d])([[:punct:]])" => s"\1 \2")
    spaced_after = replace(
        spaced_before, r"([[:punct:]])([^[:punct:]\s\d])" => s"\1 \2")
    return replace(spaced_after, r" ([[:punct:]|\d])" => s"▁\1")
end
#strip(replace(replace(text, r"([[:punct:]|\d])" => s" \1 "), r"\s+"=> ' '))


"""
    load_webnlg_files(files; istest = false)

Parse each WebNLG XML file in `files` and return a vector with one named tuple
per entry, holding `category`, `graphsize`, `eid`, `texts` and `triples`.

Texts and triple parts are passed through `nfkd_nonascii` and `punctnum_split`.
Triple heads and tails lose surrounding double quotes and have underscores
turned into spaces; edges are split with `camelsplit`.

# Keywords
- `istest`: when `false`, a `lex` element is kept only if its `comment`
  attribute is "good" or no text has been collected for the entry yet. When
  `true`, every `lex` element is kept and `comment` is not read.
"""
function load_webnlg_files(files; istest = false)
    Record = @NamedTuple{
        category::String, graphsize::Int, eid::String,
        texts::Vector{String}, triples::Vector{NTuple{3, String}}}
    records = Record[]
    for xmlfile in files
        entries = elements(readxml(xmlfile).root)[]
        for entry in eachelement(entries)
            category = intern(entry["category"])
            graphsize = parse(Int, entry["size"])
            eid = intern(entry["eid"])

            texts = String[]
            for lex in findall("lex", entry)
                if istest || lex["comment"] == "good" || isempty(texts)
                    push!(texts, punctnum_split(nfkd_nonascii(String(lex.content))))
                end
            end

            triples = NTuple{3, String}[]
            for node in findall("modifiedtripleset/mtriple", entry)
                head, edge, tail = split(nfkd_nonascii(node.content), " | ")
                head = punctnum_split(replace(strip(head, '"'), '_' => ' '))
                tail = punctnum_split(replace(strip(tail, '"'), '_' => ' '))
                edge = punctnum_split(camelsplit(edge))
                push!(triples, (intern(head), intern(edge), intern(tail)))
            end

            push!(records, (; category, graphsize, eid, texts, triples))
        end
    end
    return records
end

"""
    load_webnlg_dataset(path)

Load every file found under the directory `path` with `load_webnlg_files`,
using its default (non-test) text filtering.
"""
load_webnlg_dataset(path) = load_webnlg_files(list_webnlg_files(path))

"""
    load_webnlg_test_dataset(path)

Load the single XML file at `path` with `load_webnlg_files` in test mode
(`istest = true`).
"""
load_webnlg_test_dataset(path) = load_webnlg_files((path,); istest = true)

# Training split: every file under `<d.path>/train`.
function TrainSet(d::WebNLG2020)
    train_dir = joinpath(d.path, "train")
    return WebNLG2020DataSet(load_webnlg_dataset(train_dir))
end
# Development split: every file under `<d.path>/dev`.
function DevSet(d::WebNLG2020)
    dev_dir = joinpath(d.path, "dev")
    return WebNLG2020DataSet(load_webnlg_dataset(dev_dir))
end
# Test split: the single reference XML file below `<d.path>/test`, loaded in
# test mode.
function TestSet(d::WebNLG2020)
    test_file = joinpath(d.path, "test/semantic-parsing-test-data-with-refs-en.xml")
    return WebNLG2020DataSet(load_webnlg_test_dataset(test_file))
end

#### data processing

# Pick one of the sample's texts with `rand`; the result is asserted to be a
# `String`.
extract_text(wns::WebNLG2020Sample) = rand(wns.sample.texts)::String
# Convert the sample's triples with `triple2graph` (defined outside this file).
function extract_graph(wns::WebNLG2020Sample)
    triples = wns.sample.triples
    return triple2graph(triples)
end

# Constant marker string associated with every WebNLG sample.
function DomainToken(::WebNLG2020Sample)
    return "[WEBNLG]"
end

#### test utils

"""
    cameljoin(texts)

Concatenate `texts` without a separator, leaving the first element untouched
and applying `titlecase` to each of the others.
"""
function cameljoin(texts)
    return join(idx == 1 ? piece : titlecase(piece)
                for (idx, piece) in enumerate(texts))
end

# Concatenate `pieces`, turn every '▁' marker into a space and trim the ends.
_webnlg_detokenize(pieces) = strip(replace(join(pieces), '▁' => ' '))

# Append a `<tag>` child to `parent` whose text is "head | edge | tail", with
# the space-separated `edge` words re-joined into camel case.
function _webnlg_link_triple!(parent, tag, head, edge, tail)
    node = ElementNode(tag)
    link!(parent, node)
    relation = cameljoin(split(edge, ' '))
    link!(node, TextNode(join((head, relation, tail), " | ")))
    return node
end

# Shared builder behind both `webnlg_graph2xml` methods. `reftext` is applied
# to each part of a reference triple before it is written; it is the only
# thing that differs between the two public methods.
function _webnlg_document(reftext, preds; isref = false)
    doc = XMLDocument()
    root = ElementNode("benchmark")
    setroot!(doc, root)
    entries = ElementNode("entries")
    link!(root, entries)
    for pred in preds
        entry = ElementNode("entry")
        link!(entries, entry)
        entry["category"] = pred.category
        entry["eid"] = pred.eid
        if isref
            # Skip the `lex` node when there is no text to put in it, rather
            # than failing on `first` of an empty collection.
            if haskey(pred, :texts) && !isempty(pred.texts)
                lexnode = ElementNode("lex")
                link!(lexnode, TextNode(first(pred.texts)))
                link!(entry, lexnode)
            end
            tripleset = ElementNode("modifiedtripleset")
            link!(entry, tripleset)
            for (head, edge, tail) in pred.triples
                head = reftext(head)
                tail = reftext(tail)
                edge = reftext(edge)
                _webnlg_link_triple!(tripleset, "mtriple", head, edge, tail)
            end
        else
            generated = ElementNode("generatedtripleset")
            link!(entry, generated)
            triples = if haskey(pred, :graph)
                processed_graph_to_triple(pred.graph)
            else
                pred.triples
            end
            for (head, edge, tail) in triples
                head = _webnlg_detokenize(head)
                tail = _webnlg_detokenize(tail)
                edge = _webnlg_detokenize(edge)
                _webnlg_link_triple!(generated, "gtriple", head, edge, tail)
            end
        end
    end
    return doc
end

# Pretty-print `doc` to `path` and return `path`.
function _webnlg_save(path, doc)
    open(path, "w+") do io
        prettyprint(io, doc)
    end
    return path
end

"""
    webnlg_graph2xml(preds; isref = false)

Build an XML document `benchmark/entries/entry*` from `preds`. Each element of
`preds` must provide `category` and `eid`.

With `isref = true` every entry gets a `modifiedtripleset` whose `mtriple`
nodes are written from `pred.triples` as they are, preceded by a `lex` node
holding the first of `pred.texts` when that field exists and is non-empty.

With `isref = false` every entry gets a `generatedtripleset` whose `gtriple`
nodes come from `processed_graph_to_triple(pred.graph)` when `pred` has a
`graph` field, and from `pred.triples` otherwise. Each part is joined, has its
'▁' markers replaced by spaces and is trimmed.

In both modes the edge words are re-joined into camel case.
"""
webnlg_graph2xml(preds; isref = false) = _webnlg_document(identity, preds; isref)

"""
    write_webnlg_file(path, preds; isref = false)

Build the document with `webnlg_graph2xml(preds; isref)`, pretty-print it to
`path` and return `path`.
"""
function write_webnlg_file(path, preds; isref = false)
    return _webnlg_save(path, webnlg_graph2xml(preds; isref))
end

"""
    webnlg_graph2xml(textenc, preds; isref = false)

Same as `webnlg_graph2xml(preds; isref)`, except that with `isref = true` each
part of a reference triple is first tokenised with
`TextEncodeBase.tokenize(textenc, part)`. The tokens' `x` fields are then
joined, '▁' markers are replaced by spaces and the ends are trimmed. `textenc`
is not used when `isref = false`.
"""
function webnlg_graph2xml(textenc, preds; isref = false)
    roundtrip(text) = _webnlg_detokenize(
        map(tok -> tok.x, TextEncodeBase.tokenize(textenc, text)))
    return _webnlg_document(roundtrip, preds; isref)
end

"""
    write_webnlg_file(textenc, path, preds; isref = false)

Build the document with `webnlg_graph2xml(textenc, preds; isref)`, pretty-print
it to `path` and return `path`.
"""
function write_webnlg_file(textenc, path, preds; isref = false)
    return _webnlg_save(path, webnlg_graph2xml(textenc, preds; isref))
end
