# Abstract supertype of the TD estimator states defined in this file
# (`LinearTDState`, `TabularTDState`); also used as the bound of `TDStates{T<:TDState}`.
abstract type TDState end

# State of a linear differential-TD estimator: Q(s, a) ≈ φ(s, a)' * θ, together with an
# average-reward estimate η and Polyak (running-average) versions of θ and η.
# The struct is immutable: `update_estimator!` returns a new instance, but it mutates the
# `θ` and `θpolyak` arrays in place, so successive states share those arrays.
struct LinearTDState <: TDState
    t::Int64 # Number of updates performed so far
    φ::Function # State-action representation; φ(s, a) returns a feature vector
    θ::Vector{Float64}  # Linear coefficients
    θpolyak::Vector{Float64} # Polyak (running) average of θ over the iterates
    η::Float64 # Average reward estimate
    ηpolyak::Float64 # Polyak (running) average of η
end

"""
    LinearTDState(φ::Function, d::Int64)
    LinearTDState(φ::Function)

Create a fresh linear TD state with `d`-dimensional zero coefficients and zero reward
estimates. When `d` is omitted, it is taken from the first dimension of `φ(1, 1)`.
"""
function LinearTDState(φ::Function, d::Int64)
    # Two separate `zeros` calls so that θ and θpolyak are distinct arrays.
    return LinearTDState(0, φ, zeros(d), zeros(d), 0.0, 0.0)
end

function LinearTDState(φ::Function)
    # Probe φ once to discover the feature dimension.
    nfeatures = size(φ(1, 1), 1)
    return LinearTDState(φ, nfeatures)
end

"""
    Q(td::LinearTDState, s, a; polyak=false)

Evaluate the linear action-value estimate `φ(s, a)' * θ`, using the Polyak-averaged
coefficients instead of the raw ones when `polyak` is true.
"""
function Q(td::LinearTDState, s::Int64, a::Int64; polyak=false)
    features = td.φ(s, a)
    coefs = polyak ? td.θpolyak : td.θ
    return features' * coefs
end

# State of a tabular differential-TD estimator: Q(s, a) is stored directly as θ[s, a],
# together with an average-reward estimate η and Polyak (running-average) versions.
# The struct is immutable, but `update_estimator!` mutates the `θ` and `θpolyak` matrices
# in place and returns a new instance that shares them.
struct TabularTDState <: TDState
    t::Int64 # Number of updates performed so far
    θ::Matrix{Float64}        # Value Function Coefficients, indexed θ[s, a]
    θpolyak::Matrix{Float64}  # Polyak (running) average of θ over the iterates
    η::Float64 # Average reward estimate
    ηpolyak::Float64 # Polyak (running) average of η
end
"""
    Q(td::TabularTDState, s, a; polyak=false)

Look up the tabular action-value estimate for `(s, a)`, from the Polyak-averaged table
when `polyak` is true and from the raw table otherwise.
"""
function Q(td::TabularTDState, s::Int64, a::Int64; polyak=false)
    table = polyak ? td.θpolyak : td.θ
    return table[s, a]
end

"""
    TabularTDState(N::Int64)

Create a tabular TD state with an `N×2` coefficient matrix (indexed `[s, a]`) and an
average-reward estimate, both drawn from a standard normal distribution.

The Polyak-averaged coefficients start from the same values as `θ`, but they live in a
separate array. `update_estimator!` mutates `θ` and `θpolyak` in place, so sharing one
matrix would make the Polyak averaging overwrite the TD iterate itself.
"""
function TabularTDState(N::Int64)
    θ0 = randn((N, 2))
    η0 = randn()
    # `copy` so that θ and θpolyak do not alias the same matrix.
    return TabularTDState(0, θ0, copy(θ0), η0, η0)
end

# Bundle of the three TD estimators that are maintained together.
# `update_estimator!(mdp, ::TDStates, sarsa)` feeds the same transition to all three, each
# with a different target policy π. "co"/"tr" presumably stand for control/treatment —
# TODO confirm.
struct TDStates{T<:TDState}
    t::Int64 # Number of joint updates performed
    Vco::T # Running off-policy estimate of V_π0 (target π = (1, 0): always action 1)
    Vtr::T # Running off-policy estimate of V_π1 (target π = (0, 1): always action 2)
    Vπ::T # Running on-policy estimate under π = (1 - mdp.aC, mdp.aC); labelled V_π{1/2},
          # which presumably assumes mdp.aC == 1/2 — TODO confirm
end

"""
    LinearTDStates(φ::Function)
    LinearTDStates(d::Int64; φ=partial(cat; dims=1))

Create a `TDStates` holding three independent linear estimators (`Vco`, `Vtr`, `Vπ`).
The feature dimension is either inferred from `φ` or given explicitly as `d`. In the
second form the default `φ` presumably concatenates `(s, a)` into a vector — TODO confirm
against `partial`.
"""
function LinearTDStates(φ::Function)
    estimators = ntuple(_ -> LinearTDState(φ), 3)
    return TDStates(0, estimators...)
end

function LinearTDStates(d::Int64; φ=partial(cat; dims=1))
    estimators = ntuple(_ -> LinearTDState(φ, d), 3)
    return TDStates(0, estimators...)
end

"""
    TabularTDStates(N::Int64)

Create a `TDStates` holding three independently (randomly) initialised tabular estimators
over `N` states.
"""
function TabularTDStates(N::Int64)
    estimators = ntuple(_ -> TabularTDState(N), 3)
    return TDStates(0, estimators...)
end

"""
    update_estimator!(mdp::MDP, td::LinearTDState, sarsa::SARSA;
                      π::SVector{2, Float64}, lr=0.1, β=0.5, ɣ=0.)

Perform one off-policy Differential TD update of a linear action-value estimate, using
the approach of http://proceedings.mlr.press/v139/wan21a/wan21a.pdf.

The transition `(s, a, r, snew, anew)` is reweighted from the policy
`b = (1 - mdp.aC, mdp.aC)` to the target policy `π` by the importance weight
`π[a]/b[a] * π[anew]/b[anew]`. The TD error is `r - η + Q(snew, anew) - Q(s, a)`.

# Keywords
- `π`: target-policy action probabilities.
- `lr`: base step size; the step size at update `t` is `lr * (t + 1)^(-ɣ)`.
- `β`: multiplier on the step size used for the average-reward estimate `η`.
- `ɣ`: step-size decay exponent (`0` keeps the step size constant); integers are accepted.

# Returns
A new `LinearTDState` with counter `t + 1`, the updated `η`, and updated Polyak averages.
`td.θ` and `td.θpolyak` are updated in place and are shared with the returned state.
"""
function update_estimator!(mdp::MDP,
                           td::LinearTDState, sarsa::SARSA;
                           π::SVector{2, Float64}, lr=0.1, β=0.5, ɣ=0.)
    @unpack s, a, r, snew, anew = sarsa
    @unpack t, φ, θ, θpolyak, η, ηpolyak = td
    b = (1 - mdp.aC, mdp.aC)  # probabilities of the policy that generated the data
    ipw = (π[a] / b[a]) * (π[anew] / b[anew])
    x, xnew = φ(s, a), φ(snew, anew)
    td_error = r - η + (xnew - x)'θ
    # `float(ɣ)`: an integer base raised to a negative integer power (e.g. ɣ=1) throws a
    # DomainError; for floating-point ɣ this is identical to `(t + 1)^(-ɣ)`.
    lrt = lr * (t + 1)^(-float(ɣ))
    θ .+= (ipw * lrt * td_error) .* x
    ηnew = η + β * lrt * ipw * td_error

    # Running (Polyak) averages of the iterates; at t == 0 they equal the current iterate.
    θpolyak .*= (t / (t + 1))
    θpolyak .+= (1 / (t + 1)) .* θ
    ηpolyak_new = (t / (t + 1)) * ηpolyak + (1 / (t + 1)) * ηnew

    return LinearTDState(t + 1, φ, θ, θpolyak, ηnew, ηpolyak_new)
end

"""
    update_estimator!(mdp::MDP, td::TabularTDState, sarsa::SARSA;
                      π::SVector{2, Float64}, lr=0.1, β=0.5, ɣ=0.)

Tabular specialisation of the off-policy Differential TD update. It should be identical to
the `LinearTDState` method with a tabular (one-hot) φ, but much faster, because only the
single entry `θ[s, a]` receives the TD step.

The transition `(s, a, r, snew, anew)` is reweighted from `b = (1 - mdp.aC, mdp.aC)` to
the target policy `π` by `π[a]/b[a] * π[anew]/b[anew]`. The step size at update `t` is
`lr * (t + 1)^(-ɣ)`, and `β` scales the step size used for the average-reward estimate.
Integer `ɣ` is accepted.

Returns a new `TabularTDState` with counter `t + 1`. `td.θ` and `td.θpolyak` are updated
in place and are shared with the returned state.
"""
function update_estimator!(mdp::MDP,
                           td::TabularTDState, sarsa::SARSA;
                           π::SVector{2, Float64}, lr=0.1, β=0.5, ɣ=0.)
    @unpack s, a, r, snew, anew = sarsa
    @unpack t, θ, θpolyak, η, ηpolyak = td
    b = (1 - mdp.aC, mdp.aC)  # probabilities of the policy that generated the data
    ipw = (π[a] / b[a]) * (π[anew] / b[anew])
    td_error = r - η + θ[snew, anew] - θ[s, a]
    # `float(ɣ)`: an integer base raised to a negative integer power (e.g. ɣ=1) throws a
    # DomainError; for floating-point ɣ this is identical to `(t + 1)^(-ɣ)`.
    lrt = lr * (t + 1)^(-float(ɣ))
    θ[s, a] += ipw * lrt * td_error
    ηnew = η + β * lrt * ipw * td_error

    # The Polyak average is over whole iterates, so every entry is rescaled, not only
    # θ[s, a]: entries left unchanged this step still enter with weight 1/(t + 1).
    θpolyak .*= (t / (t + 1))
    θpolyak .+= (1 / (t + 1)) .* θ
    ηpolyak_new = (t / (t + 1)) * ηpolyak + (1 / (t + 1)) * ηnew

    return TabularTDState(t + 1, θ, θpolyak, ηnew, ηpolyak_new)
end


"""
    update_estimator!(mdp::MDP, tds::TDStates, sarsa::SARSA; kwargs...)

Feed one transition to all three estimators in `tds`:
- `Vco` targets the policy that always takes action 1.
- `Vtr` targets the policy that always takes action 2.
- `Vπ` targets `(1 - mdp.aC, mdp.aC)`.

Remaining keyword arguments (step sizes etc.) are forwarded unchanged. Returns a new
`TDStates` with the counter incremented.
"""
function update_estimator!(mdp::MDP, tds::TDStates, sarsa::SARSA; kwargs...)
    π_control = SVector(1., 0.)
    π_treatment = SVector(0., 1.)
    π_onpolicy = SVector(1 - mdp.aC, mdp.aC)

    Vco_next = update_estimator!(mdp, tds.Vco, sarsa; π=π_control, kwargs...)
    Vtr_next = update_estimator!(mdp, tds.Vtr, sarsa; π=π_treatment, kwargs...)
    Vπ_next = update_estimator!(mdp, tds.Vπ, sarsa; π=π_onpolicy, kwargs...)

    return TDStates(tds.t + 1, Vco_next, Vtr_next, Vπ_next)
end

"""
    summarize_estimator(mdp::TSRBirthDeathMDP, stats::PathStats, tds::TDStates)

Report the treatment-effect estimates implied by the current TD estimators.

Returns a vector of `Dict{Symbol, Any}` rows. Each row has the keys `:t` (the update
counter of `tds`), `:estimator` (a label) and `:estimate`:

- `"Off-Policy TD"` / `"Off-Policy TD (Polyak)"`: `Vtr.η - Vco.η`, computed from the raw
  or Polyak-averaged average-reward estimates.
- `"DQ-TD-Q"`, `"DQ-TD-V"` and their Polyak variants: the two values returned by
  `summarize_Q`. Its inputs are:
  - transition matrices estimated from the control and treatment transition counts,
    restricted to the visited states;
  - `empirical_ρ(stats)` on those states;
  - the on-policy Q-estimate `Vπ`, with raw or Polyak coefficients.

If no state has been visited yet, returns an empty vector of the same element type.
"""
function summarize_estimator(mdp::TSRBirthDeathMDP, stats::PathStats,
                             tds::TDStates)
    @unpack t, Vco, Vtr, Vπ = tds
    ss = visited_states(stats)
    # Typed empty result so both branches return Vector{Dict{Symbol, Any}}.
    isempty(ss) && return Dict{Symbol, Any}[]

    Pco = estimate_transition_matrix(
        stats.control_transitions[ss, ss], fill=true)
    Ptr = estimate_transition_matrix(
        stats.treatment_transitions[ss, ss], fill=true)
    ρ = empirical_ρ(stats)[ss]

    τdq, τdqV = summarize_Q(
        Pco, Ptr, ss, ρ, (s, a) -> Q(Vπ, s, a; polyak=false))
    τdq_polyak, τdqV_polyak = summarize_Q(
        Pco, Ptr, ss, ρ, (s, a) -> Q(Vπ, s, a; polyak=true))

    # One output row per estimator; all rows share the same update counter `t`.
    row(name, estimate) = Dict(:t => t, :estimator => name, :estimate => estimate)
    return [row("Off-Policy TD", Vtr.η - Vco.η),
            row("Off-Policy TD (Polyak)", Vtr.ηpolyak - Vco.ηpolyak),
            row("DQ-TD-Q", τdq),
            row("DQ-TD-Q (Polyak)", τdq_polyak),
            row("DQ-TD-V", τdqV),
            row("DQ-TD-V (Polyak)", τdqV_polyak)]
end
