using PartialFunctions
using Printf
using DataFrames
using DataFramesMeta
using Transducers
using MLStyle
using Distributed
using Codenamize
using CSV
using MLStyle

include("MarkovInterference.jl")

# Initial state id for `mdp`: the pair `(mdp.N - ntr(mdp), ntr(mdp))` mapped
# through `state2id(mdp, ·)`.
function s0(mdp)
    n_tr = ntr(mdp)
    pair = (mdp.N - n_tr, n_tr)
    return state2id(mdp, pair)
end
"""
    MDP(N, μ, λ, design)

Construct a `TSRBirthDeathMDP` for size `N`, rates `μ`/`λ` and the experimental
`design`, which selects the `(aL, aC)` pair passed to the constructor:

- `:co`  → `(1, 0)`
- `:tr`  → `(1, 1)`
- `:cr`  → `(1, 0.5)`
- `:lr`  → `(0.5, 1)`
- `:tsr` → with `e = exp(-λ / μ)`: `(0.5(1 - e) + e, (1 - e) + 0.5e)`

Any other symbol has no matching case in the `@match`, and MLStyle raises an error.

NOTE(review): the meaning of the trailing `[0.315, 0.3937]` argument is not
visible in this file — confirm against `TSRBirthDeathMDP`.
"""
function MDP(N::Int64, μ::Float64, λ::Float64, design::Symbol)
    e = exp(-λ / μ)
    aL, aC = @match design begin
        :co => (1.0, 0.0)
        :tr => (1.0, 1.0)
        :cr => (1.0, 0.5)
        :lr => (0.5, 1.0)
        :tsr => (0.5 * (1 - e) + e, (1 - e) + 0.5 * e)
    end
    return TSRBirthDeathMDP(N, aL, aC, μ, λ, [0.315, 0.3937])
end

"""
    RunTSR(mdp, T, estimator; seed=1, T0=5 * num_states(mdp), kwargs...)

Return a lazy (Transducers) iteration that simulates `mdp` for `T + T0` steps from
`s0(mdp)` with a `MersenneTwister(seed)` RNG. It discards the first `T0` steps and
yields, per step, a pair of scanned states:

1. a `PathStats` updated by `update_estimator!(mdp, ...)`, and
2. the state of the estimator selected by the `estimator` symbol, updated by
   `update_estimator!(mdp, ...; kwargs...)`.

Supported `estimator` symbols: `:stats`, `:linear_lspeQ`, `:linear_lspeV`,
`:linear_td`, `:tabular_td`, `:linear_gtd`, `:tabular_gtd`, `:tabular_lspeQ`,
`:tabular_lspeV`, `:tsr`. Any other symbol is not matched and MLStyle raises an error.
"""
function RunTSR(mdp, T, estimator; seed=1, T0=5 * num_states(mdp), kwargs...)
    nS = num_states(mdp)
    est_init = @match estimator begin
        :stats => PathStats(nS)
        :linear_lspeQ => LSPEState(linear_φ)
        :linear_lspeV => LSPEState(linear_Vφ)
        :linear_td => LinearTDStates(linear_φ)
        :tabular_td => TabularTDStates(nS)
        :linear_gtd => GTDStates(linear_φ)
        :tabular_gtd => GTDStates(partial(tabular_φ, mdp))
        :tabular_lspeQ => LSPEState(partial(tabular_φ, mdp))
        :tabular_lspeV => LSPEState(partial(tabular_Vφ, mdp))
        :tsr => TSRState()
    end
    stepper = simulate_mdp(MersenneTwister(seed), mdp, s0(mdp))
    # Path statistics are always tracked alongside the chosen estimator so that
    # downstream summaries can read them (e.g. the time checker in SummarizeTSR).
    stats_scan = Scan(partial(update_estimator!, mdp), PathStats(nS))
    est_scan = Scan(partial(update_estimator!, mdp; kwargs...), est_init)
    steps = 1:(T + T0)
    return steps |>
        withprogress |>
        stepper |>
        Drop(T0) |>
        Zip(stats_scan, est_scan)
end

# Module-level checker, kept only for backward compatibility with code that may
# reference `time_checker`. SummarizeTSR no longer uses it (see below).
time_checker = log_time_checker()

"""
    SummarizeTSR(mdp; kwargs...)

Return a transducer for the `(stats, estimator_state)` pairs produced by `RunTSR`.
It keeps only the pairs whose first component passes a `log_time_checker()`
predicate, calls `summarize_estimator(mdp, stats, estimator_state; kwargs...)` on
each kept pair, and flattens (`Cat`) the results.

Each call builds its own checker. The original version closed over the non-const
global `time_checker`, which was type-unstable in this per-step predicate. It was
also shared by every run in the process, so any state the checker keeps would
carry over from one run to the next.
"""
function SummarizeTSR(mdp; kwargs...)
    checker = log_time_checker()
    return Filter(x -> checker(x[1])) |>
        MapSplat(partial(summarize_estimator, mdp; kwargs...)) |>
        Cat()
end

"""
    run_and_summarize_tsr(; N, T, μ, λ, design, estimator, seed,
                          summ_kwargs=Dict(), kwargs...)

Build `MDP(N, μ, λ, design)` and run `RunTSR` on it with `estimator`, `seed` and
`kwargs...`. Summarize the run with `SummarizeTSR(mdp; summ_kwargs...)` and return
the summaries as a `DataFrame`.

Every row is tagged with the configuration that produced it, as constant columns:
- one column per entry of `summ_kwargs` (its keys are splatted as keywords, so
  they must be `Symbol`s);
- one column per extra keyword in `kwargs`;
- `:mu`, `:lam`, `:N`, `:seed`, `:design` (as `String`) and `:config` (the
  estimator symbol, as `String`).

The fixed columns are written last, so they win if a keyword shares their name.
"""
function run_and_summarize_tsr(; N, T, μ, λ, design, estimator, seed,
                               summ_kwargs=Dict(), kwargs...)
    mdp = MDP(N, μ, λ, design)
    df = RunTSR(mdp, T, estimator; seed=seed, kwargs...) |>
        SummarizeTSR(mdp; summ_kwargs...) |>
        collect |>
        DataFrame
    # Configuration columns. Each summ_kwargs column is written exactly once;
    # the original wrote them a second time at the end, which was redundant.
    for (kw, arg) in summ_kwargs
        df[!, kw] .= arg
    end
    for (kw, arg) in kwargs
        df[!, kw] .= arg
    end
    df[!, :mu] .= μ
    df[!, :lam] .= λ
    df[!, :N] .= N
    df[!, :seed] .= seed
    df[!, :design] .= String(design)
    df[!, :config] .= String(estimator)
    return df
end

"""
    run_summarize_and_save_tsr(; seed, path, kwargs...)

Run `run_and_summarize_tsr(; seed=seed, kwargs...)` and cache the resulting
DataFrame as `<path>/<codename>_s<seed>.csv`. Here `<codename>` is
`codenamize(kwargs; adjectives=2)`, so it depends on every keyword except `seed`
and `path`. If that file already exists, the run is skipped.

`path` is created if missing. The CSV is first written to a process-unique
temporary file and then renamed into place. An interrupted write therefore never
leaves a truncated file that a later call would mistake for a finished result.

Returns the CSV file name, whether it was just written or already cached.
"""
function run_summarize_and_save_tsr(; seed, path, kwargs...)
    fname = @sprintf("%s/%s_s%d.csv",
                     path, codenamize(kwargs; adjectives=2), seed)
    @show fname
    isfile(fname) && return fname
    # Fail fast on an unusable output directory, before the expensive run.
    mkpath(path)
    df = run_and_summarize_tsr(; seed=seed, kwargs...)
    # getpid() keeps temp names distinct across Distributed worker processes.
    tmp = string(fname, ".tmp", getpid())
    CSV.write(tmp, df)
    mv(tmp, fname; force=true)
    return fname
end

"""
    ηtrue(mdp)

Return the inner product `r' * ρ`. Here `ρ` is the stationary distribution of
`transition_matrix(mdp)` and `r` is `rewards` evaluated on that matrix. This is
presumably the exact long-run average reward of `mdp`.

NOTE(review): `rewards` is called on the transition matrix rather than on `mdp`;
confirm that this is the intended method.
"""
function ηtrue(mdp)
    trans = transition_matrix(mdp)
    stationary = stationary_distribution(trans)
    reward_vec = rewards(trans)
    return reward_vec' * stationary
end

# Difference of `ηtrue` between the `:tr` and `:co` designs for the same
# `(N, μ, λ)`. This is presumably the true global treatment effect.
function τtrue(N, μ, λ)
    η_tr = ηtrue(MDP(N, μ, λ, :tr))
    η_co = ηtrue(MDP(N, μ, λ, :co))
    return η_tr - η_co
end
