"""
    GTDState

Immutable snapshot of one linear gradient-TD estimator.

# Fields
- `t`: iteration counter.
- `φ`: state-action representation, called as `φ(s, a)`.
- `θ`: linear coefficients.
- `ν`: dual variable paired with `θ`.
- `η`: average reward estimate.
- `ν1`: dual variable paired with `η`.
"""
struct GTDState
    # Iteration counter
    t::Int64
    # State-action representation
    φ::Function
    # Linear coefficients
    θ::Vector{Float64}
    # Dual for θ
    ν::Vector{Float64}
    # Average reward estimate
    η::Float64
    # Dual for η
    ν1::Float64
end

"""
    GTDState(φ::Function)

Build an initial estimator state for the representation `φ`.

The counter, both scalars and both coefficient vectors start at zero. The vector
length is taken from the first dimension of `φ(1, 1)`.
"""
function GTDState(φ::Function)
    # Each vector is sized from its own probe call of the feature map.
    θ0 = zeros(size(φ(1, 1), 1))
    ν0 = zeros(size(φ(1, 1), 1))
    return GTDState(0, φ, θ0, ν0, 0.0, 0.0)
end

"""
    GTDStates

Bundle of three `GTDState` estimators that are advanced together.

# Fields
- `t`: iteration counter.
- `Vco`: running off-policy estimate of V_π0.
- `Vtr`: running off-policy estimate of V_π1.
- `Vπ`: running on-policy estimate of V_π{1/2}.
"""
struct GTDStates
    # Iteration counter
    t::Int64
    # Running off-policy estimate of V_π0
    Vco::GTDState
    # Running off-policy estimate of V_π1
    Vtr::GTDState
    # Running on-policy estimate of V_π{1/2}
    Vπ::GTDState
end

"""
    GTDStates(φ::Function)

Build a bundle of three independent, freshly initialised `GTDState`s that share
the representation `φ`, with the bundle counter at zero.
"""
function GTDStates(φ::Function)
    # Three separate constructor calls so the estimators do not share arrays.
    return GTDStates(0, GTDState(φ), GTDState(φ), GTDState(φ))
end

"""
    update_estimator!(mdp, td::GTDState, sarsa; a_target=nothing, lr=0.1, β=0.01, ɣ=0.)

Apply one gradient-TD step to `td` for the transition `sarsa` and return the
resulting `GTDState`.

Uses the Differential TD approach of
http://proceedings.mlr.press/v139/wan21a/wan21a.pdf

Despite the `!`, `td` itself is not modified: the coefficient vectors are rebuilt
and a new `GTDState` is returned, so callers must keep the return value.

# Arguments
- `mdp`: only used for dispatch; it is not read by this method.
- `td`: current estimator state.
- `sarsa`: transition providing `s`, `a`, `r`, `snew` and `anew`.

# Keywords
- `a_target`: action used for the successor features `φ(snew, ·)`. When `nothing`,
  the observed `anew` is used instead.
- `lr`: base step size.
- `β`: shrinkage coefficient applied to `θ` in its update.
- `ɣ`: step-size decay exponent; the step size at counter `t` is `lr * (t + 1)^(-ɣ)`.
"""
function update_estimator!(mdp::TSRBirthDeathMDP,
                           td::GTDState, sarsa::SARSA;
                           a_target=nothing, lr=0.1, β=0.01, ɣ=0.)
    @unpack s, a, r, snew, anew = sarsa
    @unpack t, φ, θ, ν, η, ν1 = td

    # Successor action: the requested target action if given, otherwise the one
    # actually observed. (Previously the test was inverted and the result unused,
    # so `a_target` had no effect on the update.)
    a_next = a_target === nothing ? anew : a_target
    x, xnew = φ(s, a), φ(snew, a_next)

    # Differential TD error with the average-reward estimate as baseline.
    δ = r - η + (xnew - x)' * θ
    lrt = lr * (t + 1)^(-ɣ)
    ytν = x' * ν + η * ν1

    # Build new vectors rather than mutating the ones stored in `td`.
    θnew = θ + lrt * ((x - xnew) * ytν - β * θ)
    νnew = ν + lrt * (δ - ytν) .* x
    # NOTE(review): this increment is proportional to ν1, so when ν1 starts at
    # zero (as in `GTDState(φ)`) it never changes and `η * ν1` above stays zero.
    # Confirm whether a constant feature (ν1 + lrt * (δ - ytν)) was intended.
    ν1new = ν1 + lrt * (δ - ytν) * ν1
    ηnew = η + lrt * ytν

    return GTDState(t + 1, φ, θnew, νnew, ηnew, ν1new)
end

"""
    update_estimator!(mdp, tds::GTDStates, sarsa; kwargs...)

Advance all three estimators in `tds` by one transition and return a new
`GTDStates` whose counter is `tds.t + 1`.

`Vco` is updated with `a_target=1`, `Vtr` with `a_target=2`, and `Vπ` without a
fixed target action. `kwargs` are forwarded to every per-estimator update and
are splatted after the fixed `a_target`.
"""
function update_estimator!(mdp::TSRBirthDeathMDP, tds::GTDStates, sarsa::SARSA;
                           kwargs...)
    next_t = tds.t + 1
    next_co = update_estimator!(mdp, tds.Vco, sarsa; a_target=1, kwargs...)
    next_tr = update_estimator!(mdp, tds.Vtr, sarsa; a_target=2, kwargs...)
    next_π = update_estimator!(mdp, tds.Vπ, sarsa; kwargs...)
    return GTDStates(next_t, next_co, next_tr, next_π)
end

"""
    summarize_estimator(mdp, stats::PathStats, tds::GTDStates)

Summarise the current estimates held in `tds` as a vector of result rows.

Returns `[]` when `visited_states(stats)` is empty. Otherwise returns three
`Dict`s with keys `:t`, `:estimator` and `:estimate`:
- "Off-Policy GTD": difference of the average-reward estimates, `Vtr.η - Vco.η`.
- "DQ-GTD-Q" and "DQ-GTD-V": the two values returned by `summarize_Q` for the
  linear function `θ'φ(s, a)` of the `Vπ` estimator.

`mdp` is only used for dispatch; it is not read by this method.
"""
function summarize_estimator(mdp::TSRBirthDeathMDP, stats::PathStats,
                             tds::GTDStates)
    # Nothing to report until at least one state has been visited.
    if length(visited_states(stats)) == 0
        return []
    end

    step = tds.t
    coef = tds.Vπ.θ
    feat = tds.Vπ.φ

    Pco, Ptr, ss = empirical_Ps(stats)
    ρ = empirical_ρ(stats)[ss]

    # Linear value of a state-action pair under the Vπ estimator.
    function Q(s, a)
        return coef' * feat(s, a)
    end
    τdq, τdqV = summarize_Q(Pco, Ptr, ss, ρ, Q)

    off_policy = Dict(:t => step, :estimator => "Off-Policy GTD",
                      :estimate => tds.Vtr.η - tds.Vco.η)
    dq_q = Dict(:t => step, :estimator => "DQ-GTD-Q", :estimate => τdq)
    dq_v = Dict(:t => step, :estimator => "DQ-GTD-V", :estimate => τdqV)
    return [off_policy, dq_q, dq_v]
end
