using Mmap
using JSON3

#### load helper

"""
    mmap_open(file, fsz = filesize(file))

Memory-map `fsz` bytes of `file`, starting at the beginning of the file, and return
them as a `Vector{UInt8}`.

The file is opened read-only; the handle is closed again once the mapping has been
created.
"""
function mmap_open(file, fsz = filesize(file))
    return open(file, "r") do io
        Mmap.mmap(io, Vector{UInt8}, (Int(fsz),))
    end
end

"""
    read_event_json(path, file)

Parse the JSON document stored at `joinpath(path, file)` and return the result of
`JSON3.read`.

The file contents are memory-mapped through `mmap_open` and handed to the parser as a
byte vector.

# Throws
- `ArgumentError` if `joinpath(path, file)` is not an existing regular file.
"""
function read_event_json(path, file)
    filepath = joinpath(path, file)
    # Explicit check instead of `@assert`: assertions may be disabled, and a missing
    # data file is an input problem rather than a violated internal invariant.
    isfile(filepath) || throw(ArgumentError("dataset file not found: $(filepath)"))
    return JSON3.read(mmap_open(filepath))
end

#### Dataset type definition

# Configuration for the EventNarrative dataset. `path` is the directory onto which the
# split file names (`trainset_file`, `devset_file`, `testset_file`) are joined when a
# split is loaded.
# NOTE(review): the `HasTrainSet`/`HasDevSet`/`HasTestSet` parameter presumably
# advertises which splits exist; those traits are defined outside this file.
struct EventNarrative <: DataSetConfig{Union{HasTrainSet, HasDevSet, HasTestSet}}
    path::String
end

"""
    trainset_file(::EventNarrative)

Return the file name of the training split, `"train_data.json"`.
"""
function trainset_file(::EventNarrative)
    return "train_data.json"
end

"""
    devset_file(::EventNarrative)

Return the file name of the development split, `"dev_data.json"`.
"""
function devset_file(::EventNarrative)
    return "dev_data.json"
end

"""
    testset_file(::EventNarrative)

Return the file name of the test split, `"test_data.json"`.
"""
function testset_file(::EventNarrative)
    return "test_data.json"
end

# Wrapper around one loaded split. `dataset` holds whatever `read_event_json` returned
# for that split; the `length`, `getindex` and `iterate` methods in this file forward
# to it.
struct EventNarrativeDataSet{T} <: DataSetType{Union{Indexable, ReIterable}}
    dataset::T
end

# A single dataset entry. `sample` is one element of the collection wrapped by
# `EventNarrativeDataSet`; the processing functions in this file read its `narration`,
# `entity_ref_dict` and `keep_triples` properties.
struct EventNarrativeSample{T} <: SampleType{Union{HasText, HasGraph}}
    sample::T
end

"""
    length(data::EventNarrativeDataSet)

Return the number of entries in the wrapped collection.
"""
function Base.length(data::EventNarrativeDataSet)
    return length(data.dataset)
end

"""
    getindex(data::EventNarrativeDataSet, i)

Wrap entry `i` of the underlying collection in an `EventNarrativeSample`.
"""
function Base.getindex(data::EventNarrativeDataSet, i)
    return EventNarrativeSample(data.dataset[i])
end

"""
    getindex(data::EventNarrativeDataSet, is::AbstractVector)

Look up every index in `is` individually and return the resulting samples.
"""
function Base.getindex(data::EventNarrativeDataSet, is::AbstractVector)
    return map(is) do idx
        data[idx]
    end
end

"""
    iterate(data::EventNarrativeDataSet, i = 1)

Return `(data[i], i + 1)` while `0 < i <= length(data)`, and `nothing` otherwise.
"""
function Base.iterate(data::EventNarrativeDataSet, i = 1)
    if 0 < i <= length(data)
        return (data[i], i + 1)
    end
    return nothing
end

"""
    TrainSet(d::EventNarrative)

Load the training split found under `d.path` and wrap it in an `EventNarrativeDataSet`.
"""
function TrainSet(d::EventNarrative)
    records = read_event_json(d.path, trainset_file(d))
    return EventNarrativeDataSet(records)
end

"""
    DevSet(d::EventNarrative)

Load the development split found under `d.path` and wrap it in an
`EventNarrativeDataSet`.
"""
function DevSet(d::EventNarrative)
    records = read_event_json(d.path, devset_file(d))
    return EventNarrativeDataSet(records)
end

"""
    TestSet(d::EventNarrative)

Load the test split found under `d.path` and wrap it in an `EventNarrativeDataSet`.
"""
function TestSet(d::EventNarrative)
    records = read_event_json(d.path, testset_file(d))
    return EventNarrativeDataSet(records)
end

#### data processing

"""
    extract_text(ens::EventNarrativeSample)

Return the sample's `narration` with every match of an `entity_ref_dict` key replaced
by the value stored under that key. Each key is compiled into a `Regex` before matching.
"""
function extract_text(ens::EventNarrativeSample)
    record = ens.sample
    refs = record.entity_ref_dict
    narration = record.narration
    # One `Regex => replacement` rule per dictionary entry, built lazily and splatted
    # into a single multi-pattern `replace` call.
    # NOTE(review): keys are compiled unescaped — confirm they never contain regex
    # metacharacters.
    rules = (Regex(String(name)) => mention for (name, mention) in refs)
    return replace(narration, rules...)::String
end
"""
    extract_graph(ens::EventNarrativeSample)

Build the graph for a sample by passing its `keep_triples` entry to `triple2graph`.
"""
function extract_graph(ens::EventNarrativeSample)
    triples = ens.sample.keep_triples
    return triple2graph(triples)
end

"""
    DomainToken(::EventNarrativeSample)

Return the domain marker string `"[EVENTNARRATIVE]"` for any `EventNarrativeSample`.
"""
function DomainToken(::EventNarrativeSample)
    return "[EVENTNARRATIVE]"
end
