# Debug switch: when `true`, the solver prints progress messages with `println`.
VERBOSE = false

#########################################
#            Solver
#########################################

# Settings of the RHSVI solver. Every field has a default, so `RHSVISolver()` works and
# individual settings can be overridden by keyword.
@kwdef struct RHSVISolver <: Solver 
    max_time::Float64       = 60.0              # Timeout time (seconds)
    max_depth::Int          = 150               # Max. number of exploration steps per iteration
    init_iters::Int         = 10                # The tree is rebuilt when the iteration counter hits a multiple of init_iters * 2^j (j = rebuilds so far)
    max_iters::Int          = 25_000            # Bound on the iteration counter of the main loop
    epsilon::Float64        = 0.02              # (Relative) precision
    heuristic_solver        = RFIBSolver()     # Solver used to compute the initial per-state upper values (stored as `Vsupper`)        # TODO: solve bug FIB!
    lowerbound_solver       = ZeroAlphas()      # Solver whose policy's `.alphas` seed the lower-bound alpha vectors
end

# TODO: add outer loop to build a new tree after this one is sufficiently accurate
"""
    POMDPs.solve(solver::RHSVISolver, env::POMDP)

Run RHSVI on `env` and return a `RobustAlphaVectorPolicy` over the tree's alpha vectors.

Every iteration samples a path of beliefs from the initial belief (node 1) with `explore`,
backs the sampled beliefs up in reverse order, and prunes. The loop stops once the relative
value gap at the initial belief is at most `solver.epsilon`, `solver.max_time` seconds have
elapsed, or `solver.max_iters` iterations have run. Whenever the iteration counter reaches a
multiple of `solver.init_iters * 2^j` (`j` = number of rebuilds so far), the tree is rebuilt
from the pruned alpha vectors and the current per-state upper values.
"""
function POMDPs.solve(solver::RHSVISolver, env::X) where X<:POMDP
    t0 = time()
    VERBOSE && println("Initializing...")
    tree = initialize_RHSVITree(env, solver)

    # `i` is the 1-based index of the iteration about to run, `j` the number of tree rebuilds.
    i, j = 1, 0
    while (rel_value_gap(tree, 1) > solver.epsilon) && (time() - t0 < solver.max_time) && (i <= solver.max_iters)
        VERBOSE && println("\nIteration $i (Vs: $(tree.Vlower[1]), $(tree.Vupper[1])):")
        sampled_bidxs = [1]
        bidx = 1
        h = 0
        # The root's lower bound is not modified while exploring, so the threshold is fixed
        # for this iteration. `abs` matches `rel_value_gap` and keeps the threshold
        # nonnegative when values are negative.
        threshold = abs(tree.Vlower[1]) * solver.epsilon
        while (value_gap(tree, bidx) * (discount(tree.env)^h) >= threshold &&
               h < solver.max_depth &&
               !isterminalbelief(tree.env, tree.B[bidx]))
            bidx = explore(tree, bidx)
            push!(sampled_bidxs, bidx)
            h += 1
        end
        backup!(tree, sampled_bidxs)
        VERBOSE && println("Backup done!")
        prune!(tree)
        VERBOSE && println("Pruning done! (Currently $(length(tree.Alphas)) alphas and $(count(tree.B_pointset)) pointset beliefs)")
        i += 1
        # TODO: think about how this outer loop should work...
        if mod(i, solver.init_iters * 2^j) == 0
            VERBOSE && println("Resetting tree at $i iterations:")
            Alphas = pruneAlphas(tree.Alphas, tree.B[tree.B_pointset])
            tree = initialize_RHSVITree(env, solver; Alphas=Alphas, Vs_init=tree.Vsupper)
            j += 1
        end
    end
    # `i` points at the next iteration, so `i - 1` iterations were completed.
    VERBOSE && println("Done! (In $(i - 1) iterations, with Vs: $(tree.Vlower[1]), $(tree.Vupper[1]))")
    # The full alpha set is returned.
    # NOTE(review): pruning it against the point set first (`pruneAlphas`) would shrink the
    # policy; confirm whether that was intended.
    return RobustAlphaVectorPolicy(tree.env, tree.Alphas)
end

#########################################
#               Belief Tree
#########################################

# One successor of a belief node under a fixed action (an entry of `tree.Bps[bidx][aidx]`).
struct SucessorBelief
    # Index of the successor belief in `tree.B`
    bidx::Int
    # Weight of this successor, as produced by `weighted_iterator` over the backup's belief distribution
    prob::Float64
    # Index of the observation leading to this successor, looked up in `tree.C.O`
    oidx::Int
end

# Belief tree used by the RHSVI solver. All per-belief vectors are indexed by the belief
# index `bidx` (position in `B`); index 1 is the initial belief.
@kwdef mutable struct RHSVITree
    # The POMDP being solved
    env
    # Constants obtained from `get_constants(env)`; the fields `A`, `O` and `na` are used in this file
    C::C

    # Beliefs of all nodes
    B::Vector{DiscreteHashedBelief}             = [] 
    # Bps[bidx][aidx]: successor beliefs of node `bidx` under action index `aidx`
    Bps::Vector{Vector{Vector{SucessorBelief}}} = []
    # Upper value bound per belief
    Vupper::Vector{Float64}                     = []
    # Per-state upper values (indexed via `stateindex`), used for the corner values of the upper bound
    Vsupper::Vector{Float64}                    = []
    # Lower value bound per belief
    Vlower::Vector{Float64}                     = []
    # Qupper[bidx][aidx] / Qlower[bidx][aidx]: upper / lower Q-value bounds (filled on expansion)
    Qupper::Vector{Vector{Float64}}             = []
    Qlower::Vector{Vector{Float64}}             = []
    # Per-belief uncertainty estimate; starts as Vupper - Vlower and drives exploration
    Uncertainty::Vector{Float64}                = []

    # Alpha vectors representing the lower bound
    Alphas::Vector{AlphaVector{<:Any}}          = []
    # Forwarded as `alphas_protected` to `addDominantAlphas`
    Alphas_protected::Int                       = 1
    # Flags the beliefs that are part of the point set (set when a node is expanded)
    B_pointset::BitVector                       = []
    # Flags the beliefs whose successors and Q-values have been computed
    B_expanded::BitVector                       = []
    # Set to true once `addDominantAlphas` reports a change of the alpha set
    lb_has_updated::Bool                        = false
end

"""
    initialize_RHSVITree(env, solver; Alphas=[], Vs_init=nothing)

Build a fresh `RHSVITree` for `env` whose only node is the initial belief.

# Keywords
- `Alphas`: alpha vectors to start the lower bound from. The alpha vectors of
  `solver.lowerbound_solver` are appended to a copy of this collection; the caller's
  collection is left untouched.
- `Vs_init`: per-state upper values. When `nothing`, they are computed by solving `env` with
  `solver.heuristic_solver` and calling `get_exterior_values` on the resulting policy.

`Alphas_protected` is set to the total number of initial alpha vectors.
"""
function initialize_RHSVITree(env::X, solver::RHSVISolver; Alphas=[], Vs_init=nothing) where X<:POMDP
    constants = get_constants(env)
    if isnothing(Vs_init)
        heuristic_policy = solve(solver.heuristic_solver, env)
        Vs_init = get_exterior_values(heuristic_policy)
    end
    # Append to a copy: this function has no `!` and must not modify the caller's argument.
    alphas = append!(copy(Alphas), solve(solver.lowerbound_solver, env).alphas)

    tree = RHSVITree(
        env=env,
        C=constants,
        Vsupper=Vs_init,
        Alphas=alphas,
        Alphas_protected=length(alphas),
    )
    b0 = DiscreteHashedBelief(initialstate(tree.env))
    initialize_node(tree, b0)
    return tree
end

#########################################
#               Control Flow
#########################################
"""
The exploration function from HSVI (Alg. 2).
Heuristically chooses the next successor belief of `bidx` to explore: the node is expanded
if needed, the action with the highest upper Q-value is selected, and the successor picked
by `uncertain_belief` for that action is returned (as a belief index).
"""
function explore(tree, bidx)
    if !tree.B_expanded[bidx]
        expand_node!(tree, bidx)
    end
    greedy_aidx = argmax(tree.Qupper[bidx])
    return uncertain_belief(tree, bidx, greedy_aidx)
end
"""
Recomputes the upper and lower bounds of all sampled beliefs, from the last visited belief
back to the first. The last belief is expanded first if necessary and its uncertainty is
reset to zero.
"""
function backup!(tree::RHSVITree, sampled_bidxs::Vector)
    leaf = last(sampled_bidxs)
    if !tree.B_expanded[leaf]
        expand_node!(tree, leaf)
    end
    tree.Uncertainty[leaf] = 0.0
    for k in lastindex(sampled_bidxs):-1:firstindex(sampled_bidxs)
        backup!(tree, sampled_bidxs[k])
    end
end

"""
Prune beliefs and alpha vectors. At the moment only belief pruning (`prune_beliefs!`) is active.
"""
function prune!(tree::RHSVITree)
    # HSVI prunes only sporadically; here pruning runs after every iteration, because the
    # complexity of the exploration grows massively with |Alphas|.
    # Alpha pruning (implementation from PBVI) is currently switched off:
    # tree.Alphas = pruneAlphas(tree.Alphas, tree.B[tree.B_pointset], alphas_protected=tree.Alphas_protected)
    return prune_beliefs!(tree)
end

#########################################
#              Nodes Expansion
#########################################

"""
Initializes a belief node with belief `b` and returns the index of the new node.
The node starts unexpanded, outside the point set, without successors or Q-values; its
value bounds come from `bounds` and its uncertainty is the difference between them.
"""
function initialize_node(tree, b)
    bidx = length(push!(tree.B, b))
    push!(tree.B_expanded, false)
    push!(tree.B_pointset, false)
    push!(tree.Bps, [])
    push!(tree.Qlower, [])
    push!(tree.Qupper, [])

    # `bounds` looks the new node up by index, so the belief and its flags must already be stored.
    lo, hi = bounds(tree, bidx)
    push!(tree.Vlower, lo)
    push!(tree.Vupper, hi)
    push!(tree.Uncertainty, hi - lo)

    return bidx
end

"""
Expands a belief node: for every action, creates the successor belief nodes and stores the
lower and upper Q-value bounds. Afterwards the node is marked as expanded and added to the
point set.
"""
function expand_node!(tree, bidx)
    belief = tree.B[bidx]
    for (aidx, action) in enumerate(tree.C.A)
        push!(tree.Bps[bidx], [])
        q_low, _, successor_dist = backup(tree.env, belief, action, tree.Alphas)
        for ((obs, next_belief), weight) in weighted_iterator(successor_dist)
            child = initialize_node(tree, next_belief)
            obs_idx = findfirst(isequal(obs), tree.C.O)
            push!(tree.Bps[bidx][aidx], SucessorBelief(child, weight, obs_idx))
        end
        push!(tree.Qlower[bidx], q_low)
        # The upper Q-value needs the successors of this action, so it is computed last.
        push!(tree.Qupper[bidx], upperbound(tree, bidx, aidx))
    end
    tree.B_expanded[bidx] = true
    tree.B_pointset[bidx] = true
end

"""
Intended contract: returns bidx if there is already a node with belief b, and nothing otherwise.

NOTE: this is currently a stub. No lookup is performed and `nothing` is always returned.
"""
belief_exists(tree, b) = nothing

#########################################
#              Belief Pruning
#########################################

"""
Visits every expanded point-set belief except the initial one and calls `prune_subtree!`
for each action whose upper Q-value lies below the belief's lower value bound.
"""
function prune_beliefs!(tree::RHSVITree)
    for bidx in eachindex(tree.B)
        # Only expanded, not yet pruned, non-initial beliefs are candidates
        is_candidate = tree.B_expanded[bidx] && tree.B_pointset[bidx] && bidx != 1
        is_candidate || continue

        # Condition 1: prune successor beliefs if the action is suboptimal
        v_low = tree.Vlower[bidx]
        for (aidx, _) in enumerate(tree.C.A)
            tree.Qupper[bidx][aidx] < v_low && prune_subtree!(tree, bidx, aidx)
        end
    end
end

"""
Recursively walks the subtree reached from belief `bidx` through action `aidx`, descending
into every action of every successor that is expanded and in the point set.
The step that would remove beliefs from the point set is commented out, so the traversal
currently leaves the tree unchanged. Always returns `nothing`.
"""
function prune_subtree!(tree, bidx, aidx)
    (tree.B_expanded[bidx] && tree.B_pointset[bidx]) || return nothing
    for succ in tree.Bps[bidx][aidx], apidx in 1:tree.C.na
        prune_subtree!(tree, succ.bidx, apidx)
    end
    # tree.B_pointset[bidx] = false
    return nothing
end

#########################################
#              Value bounds
#########################################
"""
Returns the tuple `(lower, upper)` of value bounds for the belief with index `bidx`.
"""
function bounds(tree, bidx)
    lower = lowerbound(tree, bidx)
    upper = upperbound(tree, bidx)
    return (lower, upper)
end
"""
Returns the lower value bound for the belief with index `bidx`: the largest value of
`dot(alpha, b)` over all alpha vectors in `tree.Alphas`.
"""
function lowerbound(tree, bidx)
    belief = tree.B[bidx]
    return maximum(dot(alpha, belief) for alpha in tree.Alphas)
end

"""
Returns the upper value bound for the belief with index `bidx`.
Terminal beliefs have value 0. Otherwise the bound is the smaller of the sawtooth bound and,
for expanded nodes, the best upper Q-value.
"""
function upperbound(tree, bidx::Int)
    if isterminalbelief(tree.env, tree.B[bidx])
        return 0.0
    end
    # Unexpanded nodes have no Q-values yet, so only the sawtooth bound applies.
    q_max = Inf
    if tree.B_expanded[bidx]
        q_max = maximum(aidx -> upperbound(tree, bidx, aidx), 1:tree.C.na)
    end
    # return upperbound_VMDP(tree, bidx)
    return min(q_max, sawtooth(tree, bidx))
end

"""
Returns the expectation of the per-state upper values `tree.Vsupper` under the belief with
index `bidx`, summed over the support of the belief.
"""
function upperbound_VMDP(tree, bidx)
    belief = tree.B[bidx]
    return sum(support(belief)) do s
        pdf(belief, s) * tree.Vsupper[stateindex(tree.env, s)]
    end
end

# sawtooth(tree, bidx::Int) = sawtooth(tree, tree.B[bidx])
"""
Sawtooth upper bound for the belief with index `bidx`.
Starts from the corner value (the per-state upper values `tree.Vsupper` weighted by the
belief) and, for every other point-set belief, lowers it by `min_ratio` times the difference
between that belief's stored upper bound and its own corner value. The smallest value found
is returned.
"""
function sawtooth(tree, bidx::Int)
    belief = tree.B[bidx]
    corner = AlphaVector(tree.Vsupper, collect(states(tree.env)), nothing)
    corner_value = dot(corner, belief)
    best = corner_value
    for other_idx in findall(tree.B_pointset)
        other_idx == bidx && continue
        other = tree.B[other_idx]
        ratio = min_ratio(belief, other)
        candidate = corner_value + ratio * (tree.Vupper[other_idx] - dot(corner, other))
        # Plain `<` rather than `min`: a NaN candidate must not replace the current best.
        if candidate < best
            best = candidate
        end
    end
    return best
end

"""
Returns the smallest ratio `b(s) / bp(s)` over the states of `bp`, or `0.0` as soon as `bp`
contains a state that is missing from `b`. Returns `Inf` when `bp` has no states.

Both state lists are walked in lockstep; states of `b` that `bp` lacks are skipped by
comparing `objectid`s (presumably both lists are sorted by `objectid`; verify against
`DiscreteHashedBelief`).
"""
function min_ratio(b::DiscreteHashedBelief,bp::DiscreteHashedBelief)
    states_b, states_bp = b.state_list, bp.state_list
    nb, nbp = length(states_b), length(states_bp)
    ratio = Inf
    i, k = 1, 1
    while i <= nb && k <= nbp
        sb, sbp = states_b[i], states_bp[k]
        if sb == sbp
            ratio = min(ratio, b.probs[i] / bp.probs[k])
            i += 1
            k += 1
        elseif objectid(sb) < objectid(sbp)
            # State only present in `b`: irrelevant for the ratio
            i += 1
        else
            # State of `bp` that `b` does not contain
            return 0.0
        end
    end
    # Leftover states in `bp` mean `b` ran out first, i.e. `bp` has states that `b` lacks.
    if k <= nbp
        return 0.0
    end
    return ratio
end


"""
Returns the upper Q-value bound on belief `bidx` and action `aidx`: the belief reward plus
the discounted, probability-weighted upper value bounds of the successor beliefs.
"""
function upperbound(tree, bidx, aidx)
    immediate = beliefreward(tree.env, tree.B[bidx], tree.C.A[aidx])
    gamma = discount(tree.env)
    # Left fold keeps the summation order of a plain accumulation loop.
    return foldl(tree.Bps[bidx][aidx]; init=immediate) do acc, succ
        acc + succ.prob * gamma * tree.Vupper[succ.bidx]
    end
end

"""
Returns the discounted, probability-weighted sum of the stored uncertainties of all
successor beliefs of `bidx`, accumulated over all actions.
"""
function uncertainty(tree, bidx)
    gamma = discount(tree.env)
    total = 0.0
    for aidx in 1:tree.C.na, succ in tree.Bps[bidx][aidx]
        total += succ.prob * gamma * tree.Uncertainty[succ.bidx]
    end
    return total
end

"""
Returns the index of the successor belief of `bidx` under action `aidx` with the largest
probability-weighted uncertainty (`prob * Uncertainty`). On ties the first successor wins.
"""
function uncertain_belief(tree, bidx, aidx)
    successors = tree.Bps[bidx][aidx]
    _, pick = findmax(succ -> succ.prob * tree.Uncertainty[succ.bidx], successors)
    return successors[pick].bidx
end
"""
Return the absolute value gap (upper bound minus lower bound) for a given belief.
"""
value_gap(tree, bidx) = tree.Vupper[bidx] - tree.Vlower[bidx]

"""
Return the relative value gap for a given belief: the absolute gap divided by the magnitude
of the lower bound.

A zero gap yields `0.0` even when the lower bound is zero (instead of the `NaN` that `0/0`
would produce). A nonzero gap with a zero lower bound yields an infinite value, so such a
belief is never considered converged in relative terms.
"""
function rel_value_gap(tree, bidx)
    gap = value_gap(tree, bidx)
    iszero(gap) && return zero(gap)
    return gap / abs(tree.Vlower[bidx])
end
"""
    backup!(tree, bidx; add_alphas=true)

Update both bounds on the (Q-)values for the given belief.

With `add_alphas=true` the lower Q-values are recomputed with `backup`, the alpha vectors of
the best action (and of actions whose lower Q-value is within 5% of the running best) are
offered to `addDominantAlphas`, and the belief's lower bound is overwritten.
With `add_alphas=false` only the upper bounds and the uncertainty are refreshed; the alpha
set and the stored lower bound are left untouched.

A diagnostic is printed when the lower bound exceeds the upper bound by more than 0.5%.
"""
function backup!(tree, bidx; add_alphas=true)
    alphas, Vlower = AlphaVector[], -Inf
    for (aidx, a) in enumerate(tree.C.A)
        # Lower Q-values
        if add_alphas
            Qlower, alpha, _ = backup(tree.env, tree.B[bidx], a, tree.Alphas)
            tree.Qlower[bidx][aidx] = Qlower
            if abs((Qlower - Vlower) / Vlower) < 0.05
                # Near-tie with the current best: keep this alpha vector as well
                Vlower = max(Qlower, Vlower)
                push!(alphas, alpha)
            elseif Qlower > Vlower
                Vlower = Qlower
                alphas = [alpha]
            end
        end
        # Upper Q-values
        tree.Qupper[bidx][aidx] = upperbound(tree, bidx, aidx)
    end
    # Without a lower-bound backup, `Vlower` is still -Inf and must not overwrite the stored bound.
    if add_alphas
        tree.Alphas, isupdated = addDominantAlphas(alphas, tree.Alphas, tree.B[tree.B_pointset], alphas_protected=tree.Alphas_protected)
        isupdated && (tree.lb_has_updated = true) # In reality isupdated is almost always true...
        tree.Vlower[bidx] = Vlower
    end
    tree.Vupper[bidx] = min(maximum(tree.Qupper[bidx]), upperbound(tree, bidx))
    tree.Uncertainty[bidx] = min(tree.Uncertainty[bidx], uncertainty(tree, bidx), tree.Vupper[bidx] - tree.Vlower[bidx])
    if tree.Vlower[bidx] > tree.Vupper[bidx] * 1.005
        println("Error: invalid bounds!")
        println("b=$(tree.B[bidx]), bidx=$bidx V=[$(tree.Vlower[bidx]), $(tree.Vupper[bidx])]")
        println("Qs=$(tree.Qlower[bidx])")
    end
    return nothing
end
