"""
    PathStats

Running statistics collected along a single sample path.

The struct itself is immutable: `update_estimator!` mutates the array fields in
place and returns a new `PathStats` (sharing those arrays) with `t` advanced by one.

# Fields
- `t::Int64`: number of samples folded in so far.
- `sum_rewards::Matrix{Float64}`: total observed reward indexed `[state, action]`.
  `PathStats(N)` allocates it as `N × 2`; column 1 accumulates action `1` (control)
  and column 2 action `2` (treatment), matching how `update_estimator!` routes actions.
- `control_transitions`: entry `[s, s′]` counts observed transitions `s → s′`
  taken under action `1`.
- `treatment_transitions`: same as `control_transitions`, for action `2`.
"""
struct PathStats
    t::Int64
    sum_rewards::Matrix{Float64}
    control_transitions::SparseMatrixCSC{Int64, Int64}
    treatment_transitions::SparseMatrixCSC{Int64, Int64}
end

"""
    PathStats(N)

Return empty statistics for an `N`-state chain: `t = 0`, an `N × 2` zero reward
matrix and two empty `N × N` `Int64` transition-count matrices.
"""
# Allocate the count matrices with the field's element type directly instead of
# building Float64 sparse matrices that must then be converted on construction.
PathStats(N) = PathStats(0, zeros(Float64, (N, 2)),
                         spzeros(Int64, N, N), spzeros(Int64, N, N))

"""
    update_estimator!(mdp, stats, sarsa) -> PathStats

Fold one `(s, a, r, snew)` sample into `stats`.

The reward `r` is added to `stats.sum_rewards[s, a]`. The transition `s → snew` is
counted in `stats.control_transitions` when `a == 1` and in
`stats.treatment_transitions` when `a == 2`. These arrays are modified in place. A new
`PathStats` that shares them, with `t` incremented by one, is returned. `mdp` is
not read.
"""
function update_estimator!(mdp::TSRBirthDeathMDP,
                           stats::PathStats, sarsa::SARSA)
    @unpack s, a, r, snew = sarsa
    stats.sum_rewards[s, a] += r
    if a == 1
        stats.control_transitions[s, snew] += 1
    elseif a == 2
        stats.treatment_transitions[s, snew] += 1
    end
    return PathStats(stats.t + 1,
                     stats.sum_rewards,
                     stats.control_transitions,
                     stats.treatment_transitions)
end

"""
    empirical_ρ(stats::PathStats)

Return, for every state `s`, the fraction of all recorded transitions (control and
treatment pooled) whose origin is `s`: row sums of the pooled count matrix divided
by its grand total.
"""
function empirical_ρ(stats::PathStats)
    pooled = stats.control_transitions + stats.treatment_transitions
    departures = vec(sum(pooled; dims=2))
    return departures ./ sum(pooled)
end

"""
    is_visited(counts::SparseMatrixCSC; thresh=0)

Return a Boolean mask over states that is `true` for state `s` when both the
number of transitions leaving `s` (row sum of `counts`) and the number entering
`s` (column sum) exceed `thresh`.
"""
function is_visited(counts::SparseMatrixCSC; thresh=0)
    leaving = vec(sum(counts; dims=2))
    entering = vec(sum(counts; dims=1))
    enough_in = entering .> thresh
    enough_out = leaving .> thresh
    return enough_in .& enough_out
end

"""
    visited_states(stats::PathStats; thresh=0)

Return the indices of the states that `is_visited` accepts on the pooled
(control + treatment) transition counts of `stats`.
"""
function visited_states(stats::PathStats; thresh=0)
    pooled = stats.control_transitions + stats.treatment_transitions
    mask = is_visited(pooled; thresh=thresh)
    return findall(mask)
end

"""
    rewards(P::SparseMatrixCSC{Float64, Int64})

Return `[0.0, P[2, 1], P[3, 2], …, P[n, n-1]]` with `n = size(P, 1)`: the first
subdiagonal of `P`, preceded by a zero for state 1.
"""
function rewards(P::SparseMatrixCSC{Float64, Int64})
    subdiagonal = [P[row, row - 1] for row in 2:size(P, 1)]
    return vcat(0.0, subdiagonal)
end

"""
    stationary_distribution(P::SparseMatrixCSC{Float64, Int64}; method=:arpack)

Estimate the stationary distribution of the transition matrix `P`. The result is an
eigenvector of `P'`, rescaled to sum to one, with its real part taken.

# Keywords
- `method`: which eigensolver to use.
  - `:arpack`: `eigs` with `nev=2`; uses the first returned eigenvector.
  - `:arnoldi`: `partialschur` + `partialeigen` with `nev=2`; uses the second
    returned eigenvector.
  - `:eigen`: dense `eigen`; uses the last eigenvector.

# Throws
- `ArgumentError` if `method` is not one of the symbols above. Previously an
  unknown method silently returned `nothing`.
"""
function stationary_distribution(P::SparseMatrixCSC{Float64, Int64}; method=:arpack)
    v = if method == :arpack
        _, vecs = eigs(P'; nev=2, ncv=100, maxiter=1000)
        vecs[:, 1]
    elseif method == :arnoldi
        decomp, _ = partialschur(collect(P'); nev=2, which=LM())
        _, vecs = partialeigen(decomp)
        # NOTE(review): takes the second of the two eigenpairs; confirm that
        # `partialeigen` places the dominant (eigenvalue 1) pair last.
        vecs[:, 2]
    elseif method == :eigen
        # Presumably `eigen` orders eigenvalues ascending by real part, so the last
        # column is the dominant eigenvector for a stochastic `P`.
        _, vecs = eigen(Matrix(P)')
        vecs[:, end]
    else
        throw(ArgumentError("unknown method $(repr(method)); " *
                            "expected :arpack, :arnoldi or :eigen"))
    end
    # Dividing by the sum also fixes the arbitrary sign of the eigenvector.
    return real.(v ./ sum(v))
end

"""
    estimate_rewards(stats::PathStats)

Return the per-(state, action) mean reward. Each entry is
`stats.sum_rewards[s, a]` divided by the number of recorded transitions out of `s`
under action `a`. Column 1 uses `control_transitions` and column 2 uses
`treatment_transitions`. Entries that evaluate to `NaN` (0/0, never-observed
pairs) are replaced by `0`.
"""
function estimate_rewards(stats::PathStats)
    n_control = sum(stats.control_transitions; dims=2)
    n_treatment = sum(stats.treatment_transitions; dims=2)
    means = stats.sum_rewards ./ hcat(n_control, n_treatment)
    means[isnan.(means)] .= 0
    return means
end

"""
    estimate_transition_matrix(counts; fill=false, thresh=0)

Build an empirical transition matrix from the transition-count matrix `counts`.

Only the states kept by `is_visited(counts; thresh=thresh)` are used. Each kept row
is divided by that state's total out-count in the *full* `counts` matrix and then
restricted to the kept columns. Transitions into dropped states are discarded
without renormalising, so rows of the result need not sum to one.

With `fill=false`, the `k × k` matrix over the kept states (in index order) is
returned. With `fill=true`, it is embedded in an `N × N` sparse matrix
(`N = size(counts, 1)`) whose rows and columns for dropped states are zero.
"""
function estimate_transition_matrix(counts; fill=false, thresh=0)
    # NOTE(review): the original author assumed there are no absorbing states.
    keep = is_visited(counts; thresh=thresh)
    row_totals = vec(sum(counts; dims=2))
    P_keep = spdiagm(1 ./ row_totals[keep]) * counts[keep, keep]
    fill || return P_keep

    n = size(counts, 1)
    embedded = spzeros(n, n)
    embedded[keep, keep] .= P_keep
    return embedded
end

"""
    lstd_off_policy(Pn, r) -> (V, η)

Solve the tabular average-reward LSTD system built from the transition-count
matrix `Pn` and the reward vector `r`:

    [Diagonal(n) - Pn   n] * [V; η] = r,    n = row sums of Pn

Returns `V` (one entry per row of `Pn`) and the scalar `η`.

The system has one more unknown than equations, so the sparse left division
returns a QR-based solution rather than a unique one. Presumably only `η` (and
differences of `V`) are meant to be interpreted.
"""
function lstd_off_policy(Pn, r)
    ns = vec(sum(Pn; dims=2))
    A = hcat(spdiagm(ns) .- Pn, ns)
    # (A previously-built sparse identity here was never used and has been removed.)
    Vη = A \ r
    return Vη[1:end-1], Vη[end]
end

"""
    lstd_on_policy(Pn, r, rbar; α=0.)

Solve the tabular on-policy LSTD system

    (Diagonal(n) - Pn + α I) V = r - n * rbar,    n = row sums of Pn

for `V`. Here `Pn` is a transition-count matrix, `r` a reward vector and `rbar` a
scalar average reward per transition.

`Diagonal(n) - Pn` maps the all-ones vector to zero, so with `α = 0` the system
is singular and the factorisation may throw (e.g. a `SingularException`). A
positive `α` adds `α * I` to regularise it.
"""
function lstd_on_policy(Pn, r, rbar; α=0.)
    visits = vec(sum(Pn; dims=2))
    lhs = (spdiagm(visits) .- Pn) + α * I
    rhs = r .- visits .* rbar
    return lhs \ rhs
end


"""
    summarize_estimator(mdp, stats0, stats; thresh=0, α=0.)

Compute two tabular estimates from `stats`, restricted to the states returned by
`visited_states(stats; thresh=thresh)`. Returns a two-element vector of `Dict`s with
keys `:t`, `:estimator`, `:estimate` and `:thresh`.

- `"Tabular LSTD (Q)"`: `lstd_on_policy` (regularisation `α * I`) on a
  state–action count matrix in which every observed control/treatment transition
  is split 50/50 between the two next actions. Rewards are centred by the pooled
  average reward per transition. The estimate is `ρ' * (Q[:, 2] - Q[:, 1])`, where
  `ρ` holds the empirical state frequencies from `empirical_ρ`. It is `NaN` when the
  solve throws a `SingularException`; other exceptions propagate.
- `"Off-Policy Tabular LSTD"`: `η_treatment - η_control`, using the average-reward
  terms returned by `lstd_off_policy` on each arm's counts and reward sums.

`mdp` and `stats0` are accepted for interface compatibility but are not read.
"""
function summarize_estimator(mdp::TSRBirthDeathMDP, stats0::PathStats,
                             stats::PathStats; thresh=0, α=0.)
    ss = visited_states(stats; thresh=thresh)
    τlstdQ = _summarize_lstd_q(stats, ss, α)
    τoff = _summarize_off_policy(stats, ss)
    return [
        Dict(:t => stats.t,
             :estimator => "Tabular LSTD (Q)",
             :estimate => τlstdQ,
             :thresh => thresh),
        Dict(:t => stats.t,
             :estimator => "Off-Policy Tabular LSTD",
             :estimate => τoff,
             :thresh => thresh),
    ]
end

# Tabular LSTD-Q treatment-effect estimate on the visited states `ss`; NaN if singular.
function _summarize_lstd_q(stats::PathStats, ss, α)
    ρ = empirical_ρ(stats)
    Cco = stats.control_transitions[ss, ss]
    Ctr = stats.treatment_transitions[ss, ss]
    # Pooled average reward per observed transition over the visited states.
    rbar = sum(sum(stats.sum_rewards; dims=2)[ss]) / sum(Cco + Ctr)
    # Each (s, a) -> s′ count is split evenly between (s′, control) and
    # (s′, treatment); presumably this models a policy that picks either action
    # with probability 1/2. Control state-actions form the first block.
    transQ = [(0.5 * Cco) (0.5 * Cco)
              (0.5 * Ctr) (0.5 * Ctr)]
    # Column-major reshape stacks control rewards above treatment rewards,
    # matching the block order of `transQ`.
    rQ = reshape(stats.sum_rewards[ss, :], 2 * length(ss))
    try
        Qπ = reshape(lstd_on_policy(transQ, rQ, rbar; α=α), (length(ss), 2))
        return ρ[ss]' * (Qπ[:, 2] - Qπ[:, 1])
    catch e
        # rethrow() keeps the original backtrace, unlike throw(e).
        e isa SingularException || rethrow()
        return NaN
    end
end

# Difference of the per-arm average-reward terms from `lstd_off_policy`.
function _summarize_off_policy(stats::PathStats, ss)
    _, ηco = lstd_off_policy(stats.control_transitions[ss, ss],
                             stats.sum_rewards[ss, 1])
    _, ηtr = lstd_off_policy(stats.treatment_transitions[ss, ss],
                             stats.sum_rewards[ss, 2])
    return ηtr - ηco
end
