
using CSV
using DataFrames
using DataFramesMeta

# Bradley–Terry model for paired comparisons (here: tennis matches).
#
# Arguments
# - `y`: one Bool outcome per match. Every caller in this file passes
#   `trues(...)`, i.e. "the player listed in `winning_player` won".
# - `n_players`: number of players, i.e. the length of the strength vector `θ`.
# - `winning_player`, `loosing_player`: per-match integer indices into `θ`.
#   (`loosing` is a misspelling of "losing", but the argument names are part of
#   the model's interface: `sample_train_batch!` replaces
#   `model.args.winning_player` / `model.args.loosing_player` / `model.args.y`
#   by name, so they must not be renamed.)
#
# Random variables
# - `σ`: scale of the player strengths; a standard normal truncated to [0, ∞),
#   i.e. a half-normal prior.
# - `θ`: player strengths, θ ~ N(0, σ²I).
# - `y[k] ~ BernoulliLogit(θ[w_k] - θ[l_k])`: the recorded winner beats the
#   loser with probability logistic(θ[w_k] - θ[l_k]).
Turing.@model function bradleyterry(y, n_players, winning_player, loosing_player)
    # Reference (presumably the benchmark setup this model follows — confirm):
    # Black Box Variational Inference with a Deterministic Objective:
    # Faster, More Accurate, and Even More Black Box. (arXiv:2304.05527v1 [cs.LG])
    
    σ ~ Turing.TruncatedNormal(0, 1, 0, Inf)
    θ ~ MvNormal(zeros(n_players), σ*σ*I)   # isotropic covariance σ²I
    p = θ[winning_player] - θ[loosing_player]   # per-match logit of "winner wins"
    @. y ~ Turing.BernoulliLogit(p)
end

"""
    atp_tennis_dataset()

Load the ATP World Tour match results and encode every match as a pair of
integer player ids.

Data source: https://datahub.io/sports-data/atp-world-tour-tennis-data

The files `match_scores_1968-1990_unindexed.csv`,
`match_scores_1991-2016_unindexed.csv` and `match_scores_2017_unindexed.csv` are
read from `datadir("datasets", "atp-tennis")` and concatenated in that order.
Runs of two or more spaces inside player names are collapsed to a single space
before ids are assigned. Ids are handed out in order of first appearance,
scanning all winner names before all loser names.

# Returns
A tuple `(winners, loosers, name_lookup, index_lookup)`:
- `winners[k]`, `loosers[k]`: ids of the winner and the loser of match `k`;
- `name_lookup`: `Dict` mapping id → player name;
- `index_lookup`: `Dict` mapping player name → id.
"""
function atp_tennis_dataset()
    csv_names = ("match_scores_1968-1990_unindexed.csv",
                 "match_scores_1991-2016_unindexed.csv",
                 "match_scores_2017_unindexed.csv")
    tables = map(csv_names) do csv_name
        DataFrame(CSV.File(datadir("datasets", "atp-tennis", csv_name)))
    end
    matches = vcat(tables...)

    # Inconsistent spacing would otherwise split one player into several ids.
    squash_spaces(names) = replace.(names, Ref(r"  +" => " "))
    winner_names = squash_spaces(matches[:, :winner_name])
    loser_names  = squash_spaces(matches[:, :loser_name])

    player_names = unique(vcat(winner_names, loser_names))
    ids          = 1:length(player_names)
    index_lookup = Dict(zip(player_names, ids))
    name_lookup  = Dict(zip(ids, player_names))

    winners = [index_lookup[name] for name in winner_names]
    loosers = [index_lookup[name] for name in loser_names]
    return winners, loosers, name_lookup, index_lookup
end

"""
    model_with_dataset(::Val{:bradleyterry}, ::Any, batchsize; rng=Random.GLOBAL_RNG)

Build the Bradley–Terry model on the ATP tennis dataset for minibatch variational
inference.

The model is instantiated on the first `min(batchsize, n_data)` matches and
evaluated under a `DynamicPPL.MiniBatchContext` with `npoints = n_data`, so batch
log likelihoods are rescaled toward the full-dataset size.

# Arguments
- `batchsize`: requested minibatch size; clamped to the number of matches.
- `rng`: random number generator used to reshuffle matches at the start of each
  epoch and passed to `rand_and_logjac` in `compute_full_elbo!`.

# Returns
A tuple `(prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate)`:
- `prob`: `DynamicPPL.LogDensityFunction` of the minibatch model.
- `b⁻¹`: `inverse(Bijectors.bijector(model))` (presumably unconstrained →
  constrained space).
- `sample_train_batch!(prob)`: updated copy of `prob` holding the next minibatch.
- `compute_full_elbo!(q, b⁻¹, M)`: `M`-draw Monte Carlo ELBO estimate on the
  full dataset.
- `validate(z)`: fraction of matches whose recorded winner has the larger
  strength `θ` in the parameter vector `z`.

# Throws
- `ArgumentError` if `batchsize < 1`.
"""
function model_with_dataset(::Val{:bradleyterry}, ::Any, batchsize; rng=Random.GLOBAL_RNG)
    batchsize ≥ 1 || throw(ArgumentError("batchsize must be positive, got $batchsize"))

    winners, loosers, name_lookup, _ = atp_tennis_dataset()
    n_data  = length(winners)
    n_batch = min(batchsize, n_data)

    # `data_itr` holds the batches left in the current epoch. It is captured and
    # reassigned by `sample_train_batch!`, which refills it with a fresh shuffle
    # once it runs dry. `first` does not consume it, so the first training batch
    # is the same one used to instantiate the model below.
    data_idx  = 1:n_data
    data_itr  = Iterators.partition(data_idx, n_batch)
    batch_idx = first(data_itr)

    # Every match is recorded from the winner's perspective, so all outcomes are `true`.
    y       = trues(n_data)
    y_batch = trues(n_batch)

    model   = bradleyterry(y_batch, length(name_lookup), winners[batch_idx], loosers[batch_idx])
    context = DynamicPPL.MiniBatchContext(; batch_size=n_batch, npoints=n_data)

    varinfo = DynamicPPL.VarInfo(model)
    b⁻¹     = inverse(Bijectors.bijector(model))
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    # Return an updated copy of `_prob` holding the next minibatch, starting a new
    # shuffled epoch when the current one is exhausted. The last batch of an epoch
    # may be shorter than `n_batch`, so the likelihood scale is recomputed per batch.
    function sample_train_batch!(_prob)
        if isempty(data_itr)
            data_itr = Iterators.partition(shuffle(rng, data_idx), n_batch)
        end
        batch_idx, data_itr = Iterators.peel(data_itr)

        _prob = @set _prob.model.args.winning_player = winners[batch_idx]
        _prob = @set _prob.model.args.loosing_player = loosers[batch_idx]
        _prob = @set _prob.model.args.y              = trues(length(batch_idx))
        _prob = @set _prob.context.loglike_scalar    = n_data/length(batch_idx)
        return _prob
    end

    # Monte Carlo ELBO on the *full* dataset from `M` draws:
    #   E[log p(y | θ)] + E[log p(θ | σ)] + E[log p(σ)] + H(q) + E[log|det J|].
    # The θ-prior term is evaluated on the GPU in Float32; the rest on the CPU.
    function compute_full_elbo!(q, b⁻¹, M)
        # `rand_and_logjac` is defined elsewhere; presumably it returns `M` draws
        # pushed through `b⁻¹` together with their summed log-Jacobians.
        zs, ∑logdetjac = rand_and_logjac(rng, q, b⁻¹, M)

        samples = unflatten_and_stack(zs, prob, [:θ, :σ])
        θ_dev   = samples[:θ] |> Array{Float32} |> CuArray   # assumed n_players × M
        σ_dev   = samples[:σ] |> Array{Float32} |> CuArray   # assumed 1 × M
        θ       = samples[:θ]
        σ       = samples[:σ][1,:]

        # E[log p(σ)] under the half-normal prior.
        𝔼ℓp_σ = mapreduce(+, 1:M) do m
            logpdf(Turing.TruncatedNormal(0, 1, 0, Inf), σ[m])
        end / M

        # E[log N(θ; 0, σ²I)] in closed form so the sum over all draws is one GPU pass:
        #   (1/M) Σ_m [ -d log σ_m - (d/2) log 2π - ‖θ_m‖² / (2σ_m²) ]
        d_θ          = size(θ_dev,1)
        ∑ℓp_θ_Z      = d_θ*(-mapreduce(log, +, σ_dev) - log(2*π)/2*M)
        ∑ℓp_θ_unnorm = -sum((θ_dev ./ σ_dev).^2)/2
        𝔼ℓp_θ        = (∑ℓp_θ_unnorm + ∑ℓp_θ_Z)/M
        𝔼ℓprior      = 𝔼ℓp_θ + 𝔼ℓp_σ

        # Winner-minus-loser strength difference per match (rows) and draw (columns).
        p      = θ[winners,:] - θ[loosers,:]
        𝔼ℓlike = mapreduce(+, 1:M) do m
            mapreduce(+, 1:n_data) do n
                logpdf(Turing.BernoulliLogit(p[n,m]), y[n])
            end / M
        end
        # `p` is n_data × M; presumably collected eagerly to release it (and the GPU
        # copies) before the next call.
        GC.gc()
        return 𝔼ℓlike + 𝔼ℓprior + entropy(q) + ∑logdetjac/M
    end

    # Fraction of matches in which the strengths θ read from `z` rank the recorded
    # winner above the loser (logistic(p) > 0.5 ⇔ p > 0).
    function validate(z)
        vi_new = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
        θ      = vi_new.metadata.θ.vals
        p      = θ[winners] - θ[loosers]

        y_pred = logistic.(p) .> 0.5
        return mean(y_pred)
    end

    return prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate
end
