using Functional
using Distributions
using StatsBase
using Parameters
using SparseArrays
using DataFrames
using DataFramesMeta
using Arpack
using Transducers
using LinearAlgebra

"""
    LSPEState

Immutable snapshot of the running quantities used by the LSPE / LSTD estimators
in this file. `update_estimator!` builds a fresh `LSPEState` on every step
instead of modifying an existing one.
"""
struct LSPEState
    t::Int64 # Number of updates applied so far
    φ::Function # Feature map, called as φ(s, a)
    θ::Vector{Float64} # Linear Q-function parameters
    A::SparseMatrixCSC{Float64, Int64} # Running average of z * (φ(s', a') - φ(s, a))'
    Binv::SparseMatrixCSC{Float64, Int64} # Inverse kept current via Sherman–Morrison
    b::Vector{Float64} # Running average of z * (r - η)
    z::SparseVector{Float64, Int64} # Eligibility trace
    η::Float64 # Estimate of average reward
end


"""
    LSPEState(d::Int64; φ=partial(cat; dims=1), λ=1., α=1.)

Build the initial state (`t = 0`) for a `d`-dimensional feature space.

# Keywords
- `φ`: feature map stored in the state.
- `λ`: scale of the initial `Binv`, which starts as `λ` times the identity.
- `α`: scale of the initial `A`, which starts as `α` times the identity.
"""
LSPEState(d::Int64; φ=partial(cat; dims=1), λ=1., α=1.) =
    LSPEState(0, φ, zeros(d), α * sparse(I, d, d), λ .* sparse(I, d, d),
              zeros(d), zeros(d), 0.)

"""
    LSPEState(φ::Function; kwargs...)

Build the initial state for feature map `φ`. The feature dimension is taken from
`size(φ(1, 1), 1)`, so `φ` must accept `(1, 1)` as a state/action pair.
Remaining keywords (`λ`, `α`) are forwarded to `LSPEState(d::Int64; ...)`.
"""
LSPEState(φ::Function; kwargs...) =
    # The positional `φ` is listed last so it wins over any `φ` inside `kwargs`.
    LSPEState(size(φ(1, 1), 1); kwargs..., φ=φ)

"""
    predict_linear(lspe, θ, s, a)

Return the inner product of the weight vector `θ` with the features
`lspe.φ(s, a)`. Only the feature map of `lspe` is used; `lspe.θ` is ignored.
"""
function predict_linear(
    lspe::LSPEState, θ::Vector{Float64}, s::Ts, a::Ta
) where {Ts, Ta}
    features = lspe.φ(s, a)
    return θ' * features
end
"""
    predict_lspe(lspe, s, a)

Linear prediction for the pair `(s, a)` using the weights stored in `lspe.θ`.
"""
function predict_lspe(lspe::LSPEState, s::Int64, a::Int64)
    weights = lspe.θ
    return predict_linear(lspe, weights, s, a)
end

"""
    update_weighted!(old, old_wt, new, new_wt)

Return the weighted combination `old * old_wt + new_wt .* new`.

Despite the trailing `!`, no argument is modified: the result is a new value,
so callers must use the return value.
"""
update_weighted!(old, old_wt, new, new_wt) = old * old_wt + new_wt .* new

"""
    update_average!(t, old, new)

Return the running mean after folding `new` into `old`, where `old` is treated
as the mean of `t` earlier samples: `old * t / (t + 1) + new / (t + 1)`.
Nothing is modified in place; use the return value.
"""
function update_average!(t, old, new)
    n = t + 1
    return update_weighted!(old, t / n, new, 1 / n)
end
"""
    sherman_morrison(Ainv, u, v)

Return the Sherman–Morrison correction term
`-(Ainv * u) * (v' * Ainv) / (1 + v' * Ainv * u)`.

This is only the increment: add it to `Ainv` to obtain the inverse of
`A + u * v'`.
"""
function sherman_morrison(Ainv, u, v)
    col = Ainv * u
    row = (Ainv' * v)'
    denom = 1 + (v' * Ainv) * u
    return -col * row ./ denom
end

"""
    update_estimator!(mdp::MDP, lspe::LSPEState, sarsa::SARSA; kwargs...)

Convenience method that accepts an `mdp` argument, ignores it, and forwards
everything else to `update_estimator!(lspe, sarsa; kwargs...)`.
"""
function update_estimator!(mdp::MDP, lspe::LSPEState, sarsa::SARSA; kwargs...)
    return update_estimator!(lspe, sarsa; kwargs...)
end

"""
    update_estimator!(lspe::LSPEState, sarsa::SARSA; λ=0., ɣ=1.)

Fold one `(s, a, r, snew, anew)` transition into the estimator and return the
resulting `LSPEState` with `t` advanced by one. Despite the `!`, `lspe` and its
arrays are not modified; callers must keep the returned state.

# Keywords
- `λ`: eligibility trace coefficient.
- `ɣ`: step size.
"""
function update_estimator!(lspe::LSPEState, sarsa::SARSA; λ=0., ɣ=1.)
    @unpack t, φ, A, Binv, b, z, θ, η = lspe
    @unpack s, a, r, snew, anew = sarsa
    x, xnew = φ(s, a), φ(snew, anew)
    znew = λ * z + x
    ηnew = (t / (t + 1)) * η + 1 / (t + 1) * r
    # The recursion rescales `Binv` by `1 / t` before the rank-one update. At
    # `t == 0` (the state produced by the constructors) that is a division by
    # zero which turns `Binv`, and then `θ`, into NaN for good. Use a normaliser
    # of 1 on the first step, i.e. treat the initial `Binv` as unnormalised.
    # For `t >= 1` this is identical to dividing by `t`.
    n = max(t, 1)
    Binvnew = update_weighted!(
        Binv, (t + 1) / n, sherman_morrison(Binv ./ n, x, x), t + 1)
    Anew = update_average!(t, A, znew * (xnew - x)')
    bnew = update_average!(t, b, znew * (r - ηnew))
    θnew = θ + ɣ .* Binvnew * (Anew * θ + bnew)
    return LSPEState(t + 1, φ, θnew, Anew, Binvnew, bnew, znew, ηnew)
end

"""
    summarize_Q(Pco, Ptr, ss, ρ, Q)

Evaluate `Q` on every state in `ss` for actions `1` ("co") and `2` ("tr") and
return a 2-tuple of `ρ`-weighted differences:

1. `ρ' * (Qtr - Qco)`, using the `Q` values directly;
2. `ρ' * (QtrV - QcoV)`, where each side is `rewards(P) + P * V` with
   `V = 0.5 * Qco + 0.5 * Qtr`.
"""
function summarize_Q(Pco, Ptr, ss, ρ, Q)
    Qco = Q.(ss, 1)
    Qtr = Q.(ss, 2)
    V = 0.5 * Qco + 0.5 * Qtr

    rco = rewards(Pco)
    rtr = rewards(Ptr)
    QcoV = rco + Pco * V
    QtrV = rtr + Ptr * V

    direct = ρ' * (Qtr - Qco)
    backed_up = ρ' * (QtrV - QcoV)
    return (direct, backed_up)
end

"""
    summarize_estimator(mdp::MDP, stats::PathStats, lspe::LSPEState)

Return a vector of `Dict`s with keys `:t`, `:estimator` and `:estimate`, one
each for "DQ-LSPE-Q", "DQ-LSTD-Q" and "DQ-LSTD-V". Returns an empty vector
when `stats` has no visited states. `mdp` is accepted but not used.

The LSTD entries are reported as `0.` whenever `lspe.A` is rank deficient,
because the LSTD weights `-A \\ b` cannot be solved for in that case.
"""
function summarize_estimator(mdp::MDP,
                             stats::PathStats,
                             lspe::LSPEState)
    ss = visited_states(stats)
    isempty(ss) && return []

    ρ = empirical_ρ(stats)[ss]
    Qhat_lspe = partial(predict_linear, lspe, lspe.θ)
    Qtr_lspe, Qco_lspe = Qhat_lspe.(ss, 2), Qhat_lspe.(ss, 1)

    τlstdQ, τlstdV =
        if rank(lspe.A) != size(lspe.A, 1)
            0., 0.
        else
            # Bind the state list from `empirical_Ps` to its own name so the
            # outer `ss` is not overwritten.
            # NOTE(review): `ρ` was indexed with `visited_states(stats)`; this
            # assumes `empirical_Ps` returns states in the same order — confirm.
            Pco, Ptr, ss_emp = empirical_Ps(stats)

            θlstd = - lspe.A \ lspe.b
            Qlstd = (s, a) -> θlstd' * lspe.φ(s, a)
            summarize_Q(Pco, Ptr, ss_emp, ρ, Qlstd)
        end

    return [Dict(:t => stats.t,
                 :estimator => "DQ-LSPE-Q",
                 :estimate => ρ' * (Qtr_lspe - Qco_lspe)),
            Dict(:t => stats.t,
                 :estimator => "DQ-LSTD-Q",
                 :estimate => τlstdQ),
            # Dict(:t => stats.t,
            #      :estimator => "lspe-V",
            #      :estimate => δ * (0.5 * Qtr_lspe .+ 0.5 * Qco_lspe)),
            Dict(:t => stats.t,
                 :estimator => "DQ-LSTD-V",
                 :estimate => τlstdV)]
end


# function dq_lspeQ(mdp::MDP, stats::PathStats, lspe::LSPEState)
#     ss = findall(ρ .> 0)
#     Qhat = partial(predict_lspe, lspe)
#     ρ[ss]' * (Qhat.(ss, 2) .- Qhat.(ss, 1))
# end

# lstd(lspe::LSPEState) = - lspe.A \ lspe.b
# function dq_lstdQ(mdp::TSRBirthDeathMDP,
#                   ρ::Vector{Float64}, lspe::LSPEState)
#     ss = findall(ρ .> 0)
#     θhat = lstd(lspe)
#     Qhat = partial(predict_linear, lspe, θhat)
#     ρ[ss]' * (Qhat.(ss, 2) .- Qhat.(ss, 1))
# end
