##################################################################
#                       Alpha Vectors 
##################################################################

"""
    AlphaVector{T}

Sparse alpha-vector: `α[i]` is the value attached to state `states[i]`, and the
vector is associated with `action`.

Fields
- `α`      : values, one per entry of `states`.
- `states` : states on which the vector is defined. Assumed to be sorted by
             `objectid`; the merge loops in `dot` / `beliefspace_dominant`
             compare `objectid`s and depend on that order.
- `action` : action associated with this vector.
- `hash`   : precomputed hash of the vector's contents.
"""
struct AlphaVector{T}
    α::Vector{Float64}
    states::Vector{T}     # sorted by objectid (assumed!)
    action::Any
    hash::UInt
end
"""
    AlphaVector(α, states, action)

Build an `AlphaVector` whose entries are reordered by the `objectid` of the
states (the order the merge-based routines in this file rely on) and whose
`hash` field is computed with `makeDBhash(states, α)` on the sorted data.

`α` may be any real-valued vector; it is converted to `Vector{Float64}`, so
e.g. integer reward vectors are accepted.

Throws `DimensionMismatch` if `α` and `states` have different lengths.
"""
function AlphaVector(α::AbstractVector{<:Real}, states::Vector{S}, action) where {S}
    length(α) == length(states) || throw(DimensionMismatch(
        "α has length $(length(α)) but states has length $(length(states))"))
    perm = sortperm(states; by=objectid)   # stable sort, same order as comparing objectids
    sorted_states = states[perm]
    sorted_α = convert(Vector{Float64}, α[perm])
    return AlphaVector(sorted_α, sorted_states, action, makeDBhash(sorted_states, sorted_α))
end

"""
    a[s]

Return the value of alpha-vector `a` at state `s` (matched with `isequal`).
If `s` is not in the support of `a`, return the large negative sentinel
`-1e10` instead.

The sentinel is a `Float64` literal (not `-10^10`): this keeps the return type
`Float64` in both branches and avoids `Int` overflow of `10^10` on 32-bit
platforms.
"""
function Base.getindex(a::AlphaVector{S}, s::S) where {S}
    sidx = findfirst(isequal(s), a.states)
    return isnothing(sidx) ? -1e10 : a.α[sidx]
end
# Vectorised lookup: `a[Ss]` applies the single-state `getindex` to every
# element of `Ss` and returns the results in an array shaped like `Ss`.
function Base.getindex(a::AlphaVector{S}, Ss::AbstractArray) where {S}
    lookup = Base.Fix1(getindex, a)
    return map(lookup, Ss)
end

"""
    dot(alphavec, b)

Expected value of `alphavec` under a generic belief/distribution `b`:
the sum over `s ∈ support(b)` of `pdf(b, s) * alphavec[s]`. States of `b`
outside the alpha-vector's support contribute the sentinel value returned by
`getindex`.
"""
function dot(alphavec::AlphaVector{<:Any}, b)
    return sum(support(b)) do state
        pdf(b, state) * alphavec[state]
    end
end

"""
    dot(alphavec, b::DiscreteHashedBelief)

Expected value of `alphavec` under the discrete belief `b`. It uses a single
merge pass over `alphavec.states` and `b.state_list`, both assumed to be
sorted by `objectid`. NOTE(review): the belief side of that assumption is not
checked here; confirm that `DiscreteHashedBelief` keeps `state_list` sorted.

Returns `-1e10` as soon as a state of `b.state_list` (including zero-probability
entries) is found to be missing from the support of `alphavec`.
"""
function dot(alphavec::AlphaVector{<:Any}, b::DiscreteHashedBelief{<:Any})
    alphainf = 1e10    # Float64 sentinel, so the function always returns a Float64
    d = 0.0
    aidx, bidx = 1, 1
    n_sup_b = length(b.state_list)
    n_sup_a = length(alphavec.states)
    while bidx <= n_sup_b
        aidx > n_sup_a && return -alphainf
        sa, sb = alphavec.states[aidx], b.state_list[bidx]
        if sb == sa
            d += b.probs[bidx] * alphavec.α[aidx]
            bidx += 1; aidx += 1
        elseif objectid(sb) < objectid(sa)
            # sb sorts before every remaining alpha state: it is not in the support.
            return -alphainf
        else
            # sa sorts before sb (or the ids tie without `==`): skip this alpha state.
            # Using a plain `else` guarantees progress, so the loop always terminates.
            aidx += 1
        end
    end
    return d
end

"""
    support_has_overlap(a, S)

Return `true` if at least one element of `S` is (by `isequal`) a state on which
alpha-vector `a` is defined, and `false` otherwise (including for empty `S`).
Stops at the first match.
"""
function support_has_overlap(a::AlphaVector, S)
    in_support(s) = any(isequal(s), a.states)
    return any(in_support, S)
end

"""
    support_is_subset(a1, a2)

Check whether every state in the support of `a1` also appears (by `isequal`)
in the support of `a2`. Returns a 2-tuple whose two entries both hold that
result.

NOTE(review): both entries come from the same test (a1 ⊆ a2), yet callers
unpack them as `(a1_dominant, a2_dominant)`. If the second entry was meant to be
a2 ⊆ a1, this is a bug; confirm the intent.
"""
function support_is_subset(a1::AlphaVector, a2::AlphaVector)
    found_in_a2(s) = !isnothing(findfirst(isequal(s), a2.states))
    is_subset = count(found_in_a2, a1.states) == length(a1.states)
    return (is_subset, is_subset)
end
            
"""
    beliefspace_dominant(a1, a2, B; delta=0.01)

Decide whether `a1` dominates `a2` and/or `a2` dominates `a1` over the belief
set `B`. Returns `(a1_dominant, a2_dominant)`.

- Vectors for different actions never dominate each other.
- Both flags start from `support_is_subset(a1, a2)`.
- For each belief `b`, with `diff = a1[s] - a2[s]` taken over the states of
  `support(b)` that lie in `a1`'s support, the score is
  `∑ diff * pdf(b, s) / sqrt(∑ diff^2)`. A score `<= delta` clears
  `a1_dominant`, and a score `>= -delta` clears `a2_dominant`. If the vectors agree on
  all those states the score is `0/0 = NaN` and neither flag is cleared.
"""
function beliefspace_dominant(a1::AlphaVector, a2::AlphaVector, B; delta=0.01)
    # Condition 1: keep vectors with different actions (for robustness).
    if a1.action !== a2.action
        return (false, false)
    end

    # Condition 2: support-based pre-check.
    dom1, dom2 = support_is_subset(a1, a2)

    # Condition 3: belief-wise domination (SARSOP-style).
    for b in B
        (dom1 || dom2) || return (false, false)
        sq_norm = 0.0
        weighted = 0.0
        for s in support(b)
            s in a1.states || continue
            gap = a1[s] - a2[s]
            sq_norm += abs2(gap)
            weighted += gap * pdf(b, s)
        end
        margin = weighted / sqrt(sq_norm)
        if margin <= delta
            dom1 = false
        end
        if margin >= -delta
            dom2 = false
        end
    end
    return dom1, dom2
end

"""
    beliefspace_dominant(a1, a2, B::AbstractVector{<:DiscreteHashedBelief}; delta=0.01)

Merge-based specialisation of `beliefspace_dominant` for discrete hashed
beliefs. It uses the same domination rules as the generic method but walks the
`objectid`-sorted state lists of `b`, `a1` and `a2` in lockstep instead of doing
linear look-ups. Returns `(a1_dominant, a2_dominant)`.

The element type is `<:DiscreteHashedBelief` because Julia's parametric types
are invariant: the previous `AbstractVector{DiscreteHashedBelief{<:Any}}` only
matched vectors with that exact abstract element type, never e.g.
`Vector{DiscreteHashedBelief{Int}}`.
"""
function beliefspace_dominant(a1::AlphaVector, a2::AlphaVector, B::AbstractVector{<:DiscreteHashedBelief}; delta=0.01)
    # Condition 1: different actions (we want to keep these for robustness)
    a1.action !== a2.action && return (false, false)

    # Condition 2: domination must happen for all states in the support
    a1_dominant, a2_dominant = support_is_subset(a1, a2)

    # Condition 3: belief-wise domination (SARSOP)
    n_a1_sup, n_a2_sup = length(a1.states), length(a2.states)

    for b in B
        !a1_dominant && !a2_dominant && return (false, false)
        bidx, a1idx, a2idx = 1, 1, 1
        sumsquared, dot_sum = 0.0, 0.0
        n_b_sup = length(b.state_list)

        while bidx <= n_b_sup && a1idx <= n_a1_sup
            sb, sa1 = b.state_list[bidx], a1.states[a1idx]

            if sb == sa1
                sa2 = a2idx <= n_a2_sup ? a2.states[a2idx] : nothing
                if !isnothing(sa2) && sa1 == sa2
                    diff = a1.α[a1idx] - a2.α[a2idx]
                    sumsquared += abs2(diff)
                    dot_sum += diff * b.probs[bidx]
                    bidx += 1; a1idx += 1; a2idx += 1
                elseif isnothing(sa2) || objectid(sa1) < objectid(sa2)
                    # sa1 is absent from a2's support, so a2 contributes 0 here.
                    # (Only reachable if the objectid ordering is inconsistent,
                    # because the early return above already requires a1 ⊆ a2.)
                    # NOTE(review): the generic method would use the getindex
                    # sentinel for a2 instead of 0.
                    diff = a1.α[a1idx]
                    sumsquared += abs2(diff)
                    dot_sum += diff * b.probs[bidx]
                    # Advance past this state; without this the loop never ends.
                    bidx += 1; a1idx += 1
                else
                    a2idx += 1
                end
            elseif objectid(sb) < objectid(sa1)
                bidx += 1
            else
                # sa1 sorts before sb (or ids tie without `==`): always make progress.
                a1idx += 1
            end
        end
        # NaN when the vectors agree on the visited states: then no flag is cleared.
        dV = dot_sum / sqrt(sumsquared)
        dV <= delta && (a1_dominant = false)
        dV >= -delta && (a2_dominant = false)
    end
    return a1_dominant, a2_dominant
end

##################################################################
#                   Alpha Vectors Policies
##################################################################

# Relative tolerance used by `action_value_distr` (as the default for its
# `randomize_epsilon` keyword) to decide which actions count as near-optimal:
# an action is kept when its value is at least `val - |val * RANDOMIZE_EPSILON|`.
# Type-annotated global (Julia >= 1.8) so reads are type-stable; it can still be
# reassigned to another Float64 value.
RANDOMIZE_EPSILON::Float64 = 1e-5

# Policy represented by a set of alpha-vectors for a (robust) POMDP.
# Normally built with `RobustAlphaVectorPolicy(env, alphas; custom_memory_update)`,
# which sorts `alphas` by action index and fills `aidxs` accordingly.
@kwdef struct RobustAlphaVectorPolicy <: Policy 
    env::X where X<:POMDP           # (R)POMDP model
    alphas::Vector{AlphaVector}     # List of alpha-vectors, assumed sorted according to actions
    aidxs::Vector{Any}              # aidxs[i]: positions in `alphas` of the vectors for the action with index i (empty if none); useful for updates
    custom_memory_update = nothing  # optional callable (π, b, a, o) -> new memory; replaces the backup-based update in `update_memory`
end

"""
    RobustAlphaVectorPolicy(env, alphas; custom_memory_update=nothing)

Build a policy from `alphas`. The vectors are sorted (stably) by the index of
their action. Then, for every action of `env` (in `actions(env)` order), the
contiguous range of positions holding that action's vectors is stored in
`aidxs`, or an empty vector if the action has none.
"""
function RobustAlphaVectorPolicy(env, alphas; custom_memory_update=nothing)
    unsorted_actidxs = map(alpha -> actionindex(env, alpha.action), alphas)
    order = sortperm(unsorted_actidxs)
    sorted_actidxs = unsorted_actidxs[order]
    sorted_alphas = alphas[order]

    aidxs = []
    for act in actions(env)
        target = actionindex(env, act)
        first_pos = findfirst(==(target), sorted_actidxs)
        if isnothing(first_pos)
            push!(aidxs, [])
        else
            push!(aidxs, first_pos:findlast(==(target), sorted_actidxs))
        end
    end
    return RobustAlphaVectorPolicy(env, sorted_alphas, aidxs, custom_memory_update)
end

# function get_action_probs_v1(b, alphas, epsilon=0.01)
#     probs = zeros(length(alphas))
#     for s in support(b)
#         val, aidx = findmax(a -> a[s], alphas)
#         aidxs = findall(a->isless(val*(1-epsilon), a[s]), alphas)
#         probs[aidxs] = probs[aidxs] .+ (pdf(b,s) * 1/length(aidxs))
#     end
#     return probs 
# end

"""
    get_action_probs(b, alphas, min_prob=0.0)

Choose probabilities `ps` (one per entry of `alphas`, each in `[min_prob, 1]`,
summing to 1) for mixing the candidate alpha-vectors `alphas` at belief `b`.
`action_value_distr` passes one best vector per near-optimal action.

First LP: with `Qs[s] = ∑_a ps[a] * alphas[a][s]` and `Q = ∑_s b(s) * Qs[s]`,
minimise `∑_s |Qs[s] - Q|`. This picks the mixture whose per-state value is as
flat as possible over `support(b)`.

If the solver reports more than one solution, a second solve keeps that error
within 0.5% of its optimum and maximises `∑_a ps[a] * E_b[alphas[a]]`.

Returns the vector of values of `ps`. Requires Gurobi and the global `GRB_ENV`
(defined elsewhere).
"""
function get_action_probs(b, alphas, min_prob=0.0)
    # model = Model(Clp.Optimizer; add_bridges=false)
    # model = Model(Gurobi.Optimizer; add_bridges=false)
    # set_silent(model)
    # set_string_names_on_creation(model, false)
    model = direct_generic_model(Float64, Gurobi.Optimizer(GRB_ENV))
    set_silent(model)
    set_string_names_on_creation(model, false)
    # Ask Gurobi to keep a pool of solutions, so ties can be detected via result_count.
    set_optimizer_attribute(model, "PoolSearchMode", 2)
    set_optimizer_attribute(model, "PoolSolutions", 1_000)

    Ss = support(b)
    nmbr_optimal_actions = length(alphas)

    @variable(model, min_prob <= ps[1:nmbr_optimal_actions] <= 1.0)
    @constraint(model, sum(ps) == 1.0)
    @variable(model, Q)                            # expected value of the mixture under b
    @variable(model, Qs[1:length(Ss)])             # per-state value of the mixture
    @variable(model, Qs_err[i=1:length(Ss)] >= 0)  # |Qs[s] - Q|, via the two constraints below
    # this includes Qs_err[i,i], which is dumb, but I can't think of a nicer way to write this...
    @constraint(model, Q == sum(sidx -> pdf(b,Ss[sidx]) * Qs[sidx], 1:length(Ss)))

    for (sidx,s) in enumerate(Ss)
        @constraint(model, Qs[sidx] == sum(aidx -> ps[aidx] * alphas[aidx][Ss[sidx]], 1:nmbr_optimal_actions))
        @constraint(model, Qs_err[sidx] >=  (Qs[sidx] - Q))
        @constraint(model, Qs_err[sidx] >=  (Q - Qs[sidx]))
    end
    @objective(model, Min, sum(Qs_err))
    optimize!(model)

    # Break ties by prefering 'safe' actions, i.e. actions with low variance
    # NOTE(review): the tie-break objective below maximises the b-expected value
    # of each alpha-vector weighted by ps; despite its name, `variance` holds
    # E_b[alphas[aidx]], not a variance. Confirm which criterion is intended.
    if result_count(model) > 1
        EPSILON_OPTIMALITY = 0.005
        minerror = objective_value(model)
        # Keep the spread error within 0.5% of the optimum found above.
        @constraint(model, sum(Qs_err) <=  minerror*(1+EPSILON_OPTIMALITY))
        variance = []
        for (aidx, a) in enumerate(ps[1:nmbr_optimal_actions])
            push!( variance, sum( s -> pdf(b,s) * alphas[aidx][s], support(b)))
        end
        @objective(model, Max, sum(variance .* ps))
        optimize!(model)
    end

    probs = JuMP.value.(ps)
    # println(b)
    # println(probs)
    # println(alphas)
    # println("---")
    return probs
end

# action_value_distr(π, b; randomize_epsilon=RANDOMIZE_EPSILON) -> (dist, val)
#
# For every action, take the best value `dot(alpha, b)` among that action's
# alpha-vectors; `val` is the overall best. If `b` has a single support state,
# `dist` is the best action with probability 1. Otherwise every action whose
# best value is at least `val - |val * randomize_epsilon|` is kept, and
# `get_action_probs` mixes them (a single survivor gets probability 1).
# Results are memoized (LRU, 100_000 entries).
@memoize LRU(maxsize=100_000) function action_value_distr(π::RobustAlphaVectorPolicy, b; randomize_epsilon=RANDOMIZE_EPSILON)
    env = π.env
    acts = actions(env)
    na = length(acts)

    best_values = fill(-Inf, na)
    best_alphas = Vector{Any}(undef, na)
    for alpha in π.alphas
        v = dot(alpha, b)
        i = actionindex(env, alpha.action)
        if v > best_values[i]
            best_values[i] = v
            best_alphas[i] = alpha
        end
    end

    val, best_aidx = findmax(best_values)
    if length(support(b)) == 1
        return SparseCat([acts[best_aidx]], [1.0]), val
    end

    threshold = val - abs(val * randomize_epsilon)
    mask = [best_values[i] >= threshold for i in 1:na]
    candidates = best_alphas[mask]
    probs = length(candidates) > 1 ? get_action_probs(b, candidates) : [1.0]
    isempty(probs) && println("Error: no optimal action picked for belief $b!")
    return SparseCat(acts[mask], probs), val
end

"""
    action_value(π, b; randomize_epsilon=RANDOMIZE_EPSILON)

Sample an action from the distribution computed by `action_value_distr` and
return it together with the value of belief `b`.
"""
function action_value(π::RobustAlphaVectorPolicy, b; randomize_epsilon=RANDOMIZE_EPSILON)
    dist, value = action_value_distr(π, b; randomize_epsilon=randomize_epsilon)
    return rand(dist), value
end
"""
    action_distr(π, b; randomize_epsilon=RANDOMIZE_EPSILON)

Return only the action distribution computed by `action_value_distr`.
"""
function action_distr(π::RobustAlphaVectorPolicy, b; randomize_epsilon=RANDOMIZE_EPSILON)
    return first(action_value_distr(π, b; randomize_epsilon=randomize_epsilon))
end

# POMDPs.jl interface. Both go through `action_value`, so `value` also samples
# (and discards) an action.
function POMDPs.action(π::RobustAlphaVectorPolicy, b)
    chosen, _ = action_value(π, b)
    return chosen
end

function POMDPs.value(π::RobustAlphaVectorPolicy, b)
    _, v = action_value(π, b)
    return v
end

# Type of the memory (belief) kept by this policy: a `DiscreteHashedBelief`
# parameterised by the state type of `π.env`. `RobustAlphaVectorPolicy` has no
# `constants` field, so the state type is taken from the model.
# NOTE(review): assumes DiscreteHashedBelief's type parameter is the state type.
get_memory_type(π::RobustAlphaVectorPolicy) = DiscreteHashedBelief{POMDPs.statetype(π.env)}
"""
    get_initial_memory(π)

Convert the initial state distribution of `π.env` into a `DiscreteHashedBelief`
over its support.
"""
function get_initial_memory(π::RobustAlphaVectorPolicy)
    init_dist = initialstate(π.env)
    init_states = support(init_dist)
    init_probs = map(s -> pdf(init_dist, s), support(init_dist))
    return DiscreteHashedBelief(init_states, init_probs)
end

# Fallback for beliefs that are not yet `DiscreteHashedBelief`s: convert, then
# dispatch to the `DiscreteHashedBelief` method.
function update_memory(π::RobustAlphaVectorPolicy, b, a, o)
    hashed = DiscreteHashedBelief(b)
    return update_memory(π, hashed, a, o)
end
# update_memory(π, b::DiscreteHashedBelief, a, o) -> next belief
#
# Uses `π.custom_memory_update(π, b, a, o)` when it is set. Otherwise it runs
# `backup(π.env, b, a, π.alphas)` and returns the next belief paired with
# observation `o`. Results are memoized (LRU, 100_000 entries).
@memoize LRU(maxsize=100_000) function update_memory(π::RobustAlphaVectorPolicy, b::DiscreteHashedBelief, a, o)
    if !isnothing(π.custom_memory_update)
        return π.custom_memory_update(π, b, a, o)
    end
    _, _, Bdistr = backup(π.env, b, a, π.alphas)
    Os = getindex.(support(Bdistr), 1)
    Bos = getindex.(support(Bdistr), 2)
    isempty(Os) && println("Error: impossible observation $o for belief $b and action $a found.")
    idx = findfirst(isequal(o), Os)
    # Observations with 0 probability are not included in Os, but may be defined.
    # Fall back to the initial belief, in the same DiscreteHashedBelief form as
    # `get_initial_memory`, so this method always returns that type.
    isnothing(idx) && return get_initial_memory(π)
    return Bos[idx]
end

"""
    get_exterior_values(π)

Vector indexed by `stateindex`: entry `i` is the largest value any alpha-vector
of `π` assigns to the state with index `i`.
"""
function get_exterior_values(π::RobustAlphaVectorPolicy)
    env = π.env
    upper = zeros(length(states(env)))
    for state in states(env)
        best = maximum(alpha -> alpha[state], π.alphas)
        upper[stateindex(env, state)] = best
    end
    return upper
end

# get_Qmax(π) -> Vector{Float64}
#
# For each state in `states(π.env)` (in iteration order), the largest value any
# alpha-vector of `π` assigns to it. The result is a typed Float64 vector instead
# of `Any[]`, so downstream code stays type-stable. Memoized (LRU, 10 entries).
@memoize LRU(maxsize=10) function get_Qmax(π::RobustAlphaVectorPolicy)
    Qmax = Float64[]
    for s in states(π.env)
        push!(Qmax, maximum(alpha -> alpha[s], π.alphas))
    end
    return Qmax
end

# get_Q(π) -> Matrix{Float64}
#
# Q[stateindex(s), actionindex(a)] = the largest value alpha[s] among the
# alpha-vectors of `π` whose action is `a`. Entries are initialised to -Inf
# (not 0), so negative values are not clamped to zero. Actions without any
# alpha-vector keep -Inf. Memoized (LRU, 10 entries).
@memoize LRU(maxsize=10) function get_Q(π::RobustAlphaVectorPolicy)
    env = π.env
    Q = fill(-Inf, length(states(env)), length(actions(env)))
    for s in states(env)
        sidx = stateindex(env, s)
        for alpha in π.alphas
            aidx = actionindex(env, alpha.action)
            Q[sidx, aidx] = max(Q[sidx, aidx], alpha[s])
        end
    end
    return Q
end

##################################################################
#                       Robust Backup
##################################################################

"""
    get_relevant_sets(env, b, a)

Return `(Ss, Sps, Os)` as vectors:
- `Ss`: the states in `support(b)`;
- `Sps`: next states `sp` with `sup(pdf(transition(env, s, a), sp)) > 0` for
  some `s` in `Ss`;
- `Os`: observations with positive probability under `observation(env, a, sp)`
  for some such `sp`.
"""
function get_relevant_sets(env, b, a)
    Ss, Sps, Os = Set(), Set(), Set()
    for s in support(b)
        push!(Ss, s)
        T = transition(env, s, a)
        for sp in support(T)
            sup(pdf(T, sp)) > 0.0 || continue
            push!(Sps, sp)
            obs_dist = observation(env, a, sp)
            for o in support(obs_dist)
                if pdf(obs_dist, o) > 0.0
                    push!(Os, o)
                end
            end
        end
    end
    return collect(Ss), collect(Sps), collect(Os)
end

"""
Robust backup for a single belief-action pair, following Osogami (2015)
"""
# backup(env::IPOMDP, b, a, Alphas::Vector{<:AlphaVector}) = backup(env, DiscreteHashedBelief(b), a, Alphas)

function get_error(alpha::AlphaVector, b::DiscreteHashedBelief)
    error = 0.0
    for s in support(b)
        for sp in support(b)
            error += abs(alpha[s] - alpha[sp])
        end
    end
    return error
end

"""
    prob_o_given_s(env, T, a, o)

Return the sum over `sp ∈ support(T)` of
`pdf(T, sp) * pdf(observation(env, a, sp), o)`, accumulated left to right
starting from `0.0`.
"""
function prob_o_given_s(env, T, a, o)
    return foldl(support(T); init=0.0) do acc, sp
        acc + pdf(T, sp) * pdf(observation(env, a, sp), o)
    end
end

"""
    is_valid_alpha(env, alpha, a, o, Sps)

Return `true` if every next state in `Sps` that can emit observation `o` under
action `a` (positive `pdf`) is in the support of `alpha`; `false` otherwise.
"""
function is_valid_alpha(env, alpha, a, o, Sps)
    return all(Sps) do sp
        !(pdf(observation(env, a, sp), o) > 0.0) || any(isequal(sp), alpha.states)
    end
end

"""
    approximate_belief(b; threshold=1e-6)

Return a new belief over `b.state_list` in which every probability `<= threshold`
is set to zero. The remaining probabilities are divided by `1 - removed_mass`,
which renormalises them if the input probabilities sum to 1.

The input belief is NOT modified: its probability vector is copied first.
Mutating it would corrupt the caller's belief and any memoization cache keyed
on it. NOTE(review): if (almost) all mass is below the threshold, the division
is by (almost) zero.
"""
function approximate_belief(b::DiscreteHashedBelief; threshold=1e-6)
    probs = copy(b.probs)
    prob_removed = 0.0
    for sidx in eachindex(probs)
        if probs[sidx] <= threshold
            prob_removed += probs[sidx]
            probs[sidx] = 0.0
        end
    end
    return DiscreteHashedBelief(b.state_list, probs ./ (1 - prob_removed))
end


"""
    backup(env::IPOMDP, b::DiscreteHashedBelief, a, Alphas)

Robust backup for a single belief-action pair, following Osogami (2015).

Returns `(value, alpha, bos)`:
- `value = dot(alpha, b)` (`0.0` for terminal beliefs);
- `alpha`, the new alpha-vector for action `a` over the support of `b`;
- `bos`, a `SparseCat` over `(o, b_o)` pairs: next beliefs under the worst-case
  transition found in step 1, weighted by their probabilities.

Steps:
1. LP (Osogami, Eq. 5): pick joint probabilities P(o, sp | s) within the
   transition intervals that minimise the sum over observations of the best
   alpha-vector value.
2. Build the next beliefs and keep only the alpha-vectors that are
   (near-)optimal for at least one of them.
3. LP: pick per-observation mixtures of those alpha-vectors whose value stays
   within `[Q_slack, bestQ]`, also when transition mass is shifted between
   next states within the available slack.
"""
function backup(env::IPOMDP, b::DiscreteHashedBelief, a, Alphas::Vector{<:AlphaVector})

    # Get relevant variables (probabilities <= 1e-6 are zeroed first):
    b = approximate_belief(b)
    Ss, Sps, Os = get_relevant_sets(env, b, a)
    Alphas = Alphas[map(alpha -> support_has_overlap(alpha, Sps), Alphas)]

    if isterminalbelief(env, b)
        return 0.0, AlphaVector(zeros(length(Ss)), Ss, a), SparseCat([(observations(env)[1], b)], [1.0])
    end

    ###
    # 1 : find worst-case probabilities & value
    ###

    model = direct_generic_model(Float64, Gurobi.Optimizer(GRB_ENV))
    set_silent(model)
    set_optimizer_attribute(model, "PoolSearchMode", 2)
    set_optimizer_attribute(model, "PoolSolutions", 1_000)

    # Set up LP (Osogami, Eq. 5)
    @variable(model, Qo[1:length(Os)])  # P(o|b) * Q(bp|o) for each different observation
    @variable(model, 0.0 <= ps[1:length(Ss), 1:length(Os), 1:length(Sps)] <= 1.0)   # P(o,sp|s)
    for (sidx, s) in enumerate(Ss)
        @constraint(model, sum(ps[sidx,:,:]) == 1.0)
    end

    # Constraint 1: ps fall within intervals
    # Constraint 2: observation probabililies are correct
    for (sidx, s) in enumerate(Ss)
        thisT = transition(env,s,a)
        for (spidx, sp) in enumerate(Sps)
            thisInt = pdf(thisT, sp)
            @constraint(model, inf(thisInt) <= sum(ps[sidx, :, spidx]))
            @constraint(model, sum(ps[sidx, :, spidx]) <= sup(thisInt))
            for (oidx,o) in enumerate(Os)
                prob_o_given_sp = pdf(observation(env,a,sp), o) # Future work: this could also be an interval!
                @constraint(model, ps[sidx,oidx,spidx] == prob_o_given_sp * sum(ps[sidx,:,spidx]))
            end
        end
    end

    # Constraint 3: Qo corresponds to optimal next action
    # ∀α,o: Qo := P(o|b) * Q(bp|o) ≥ ∑ b(s) * ∑ P(sp,o|s) * α[sp]
    for (oidx, o) in enumerate(Os)
        for alpha in Alphas
            if is_valid_alpha(env,alpha,a,o,Sps)
                @constraint(model, Qo[oidx] >=  sum(sidx -> pdf(b,Ss[sidx]) .* sum(spidx -> ps[sidx,oidx,spidx] .* alpha[Sps[spidx]], 1:length(Sps)), 1:length(Ss) ))
            end
        end
    end

    # Solve LP: find worst-case transition
    @objective(model, Min, sum(Qo))
    optimize!(model)

    bestQ = objective_value(model)
    Qos = JuMP.value.(Qo)
    EPSILON_OPTIMALITY = 0.0001
    # Relax a value slightly downwards, whatever its sign.
    slackify = (val -> min(val*(1+EPSILON_OPTIMALITY), val* (1-EPSILON_OPTIMALITY)))
    Q_slack = slackify(bestQ)
    Qos_slack = map(Q -> slackify(Q), Qos)

    ###
    # 2 : Precompute relevant probabilities and alpha-vectors for next step
    ###

    # Worst-case P(o, sp | s), and the b-weighted next-state distribution T.
    Tsplit = JuMP.value.(ps)
    T = zeros(length(Sps))
    for (sidx, s) in enumerate(Ss)
        for (spidx, sp) in enumerate(Sps)
            T[spidx] += pdf(b, s) * sum(Tsplit[sidx,:,spidx])
        end
    end

    # slack[sp1, sp2]: mass (capped at 0.05) that could be moved from sp1 to sp2
    # while staying inside the transition intervals.
    slack = zeros(Float64, length(Sps), length(Sps))
    for (sidx, s) in enumerate(Ss)
        this_T = transition(env,s,a)
        for (sp1idx, sp1) in enumerate(Sps)
            pdf(this_T, sp1) == 0.0 && continue
            for (sp2idx, sp2) in enumerate(Sps)
                # Parenthesised: `x || y && continue` parses as `x || (y && continue)`.
                (pdf(this_T, sp2) == 0.0 || sp1idx == sp2idx) && continue
                slack_sp1 = sum(Tsplit[sidx,:,sp1idx]) - inf(pdf(this_T, sp1))
                slack_sp2 = sup(pdf(this_T, sp2)) - sum(Tsplit[sidx,:,sp2idx])
                slack[sp1idx,sp2idx] = max(slack[sp1idx,sp2idx], min(slack_sp1, slack_sp2, 0.05))
            end
        end
    end

    bos, bo_probs = [], []
    is_optimal = falses(length(Alphas))
    for (oidx, o) in enumerate(Os)
        bo_vector = map(spidx -> sum(sidx -> pdf(b,Ss[sidx]) * Tsplit[sidx,oidx,spidx], 1:length(Ss)), 1:length(Sps))
        prob = sum(bo_vector)
        bo = DiscreteHashedBelief(Sps, vec(bo_vector) ./ prob)
        push!(bos, (o, bo))
        push!(bo_probs, prob)
        for (aidx, alpha) in enumerate(Alphas)
            thisvalue = 0.0
            for sp in support(bo)
                thisvalue += pdf(bo,sp) * alpha[sp]
            end
            thisvalue * prob >= Qos_slack[oidx] && (is_optimal[aidx] = true)
        end
    end
    Alphas = Alphas[is_optimal]

    ###
    # 3 : Pick probabilities for \alphas that are optimal for worst-case and after slack
    ###

    model = direct_generic_model(Float64, Gurobi.Optimizer(GRB_ENV))
    set_silent(model)
    @variable(model, 0 <= alpha_probs[1:length(Os), 1:length(Alphas)] <=1)
    for (oidx, o) in enumerate(Os)
        @constraint(model, sum(alpha_probs[oidx, :]) == 1.0)
    end
    @variable(model, alphastar[1:length(Ss)])
    @variable(model, bestQ >= Q >= Q_slack)
    @constraint(model, Q == sum(sidx -> pdf(b,Ss[sidx]) * alphastar[sidx], 1:length(Ss)))

    # Define alphastar
    for (sidx,s) in enumerate(Ss)
        @constraint(model, alphastar[sidx] <= (
            sum( oidx -> sum(spidx -> sum( aidx ->
                Tsplit[sidx,oidx,spidx] * alpha_probs[oidx,aidx] * Alphas[aidx][Sps[spidx]],
            1:length(Alphas)), 1:length(Sps)), 1:length(Os))
        ))
    end

    # Constraint: alpha-vectors must have optimal value for all slack-points
    for (sp1idx, sp1) in enumerate(Sps)
        for (sp2idx, sp2) in enumerate(Sps)
            # Parenthesised: `x || y && continue` parses as `x || (y && continue)`.
            (slack[sp1idx, sp2idx] == 0.0 || sp1idx == sp2idx) && continue
            # Perturb a COPY of T: assigning `thisT = T` would alias T, so the
            # shifts would accumulate over all pairs and corrupt T itself.
            #TODO: this should represent T-next, now it's the transition function!
            thisT = copy(T)
            thisT[sp1idx] -= slack[sp1idx, sp2idx]
            thisT[sp2idx] += slack[sp1idx, sp2idx]
            @constraint(model, Q <= (
                sum(oidx -> sum(spidx -> sum(aidx ->
                    thisT[spidx] * pdf(observation(env,a,Sps[spidx]), Os[oidx]) *
                    alpha_probs[oidx,aidx] * Alphas[aidx][Sps[spidx]],
                1:length(Alphas)), 1:length(Sps)), 1:length(Os))
            ))
        end
    end

    # Alpha-vectors that cannot explain observation o get zero weight for o.
    for (oidx, o) in enumerate(Os)
        for (aidx, alpha) in enumerate(Alphas)
            if !is_valid_alpha(env, alpha,a,o,Sps)
                @constraint(model, alpha_probs[oidx,aidx] == 0.0)
            end
        end
    end

    # Penalise states whose value falls below Q (weighted by b).
    @variable(model, alpha_diff[1:length(Ss)] >= 0.0)
    for (sidx, s) in enumerate(Ss)
        @constraint(model, alpha_diff[sidx] >= pdf(b,s) * (Q - alphastar[sidx]))
    end

    @objective(model, Max, Q - 0.01 * sum(alpha_diff))
    optimize!(model)
    if termination_status(model) != MOI.OPTIMAL
        println("Error!")
        println("a=$a, b=$b")
        println("Alphas = $Alphas")
        println("states = $Ss")
        println(model)
        println("T = $T")
        println("Tsplit = $Tsplit")
    end

    if JuMP.value.(Q) < Q_slack
        println("Q=$(JuMP.value.(Q)), Qslack=$Q_slack")
        println(model)
    end
    alphastar = discount(env) .* JuMP.value.(alphastar) .+ map(s -> reward(env,s,a), Ss)
    # Near-zero-probability states never get a positive value.
    for (sidx, s) in enumerate(Ss)
        pdf(b,s) < 1e-5 && (alphastar[sidx] = min(alphastar[sidx], 0.0))
    end
    alphastar = AlphaVector(alphastar, Ss, a)

    # Sanity check: the future part of the new value must not exceed the
    # (discounted) worst-case optimum by more than 0.5%.
    if dot(alphastar,b) - sum(map(s -> pdf(b,s) * reward(env,s,a), Ss)) > discount(env) * bestQ * 1.005 && bestQ > 0.01
        println("Error: Backup incorrect!")
        println("$a, $b")
        println("$(dot(alphastar,b) - sum(map(s -> pdf(b,s) * reward(env,s,a), Ss))), $bestQ")
        println("$Sps")
        println("$alphastar")
        println(bos)
        println(bo_probs)
        println("!---!")
    end

    return dot(alphastar,b), alphastar, SparseCat(bos, bo_probs)
end

"""
    backup(env::POMDP, b::DiscreteHashedBelief, a, Alphas)

Standard (non-robust) point-based backup of belief `b` for action `a`.

For each relevant observation `o`, build the next belief `b_o` and pick the
alpha-vector `α_o` in `Alphas` with the largest `dot(α_o, b_o)`. The new vector is

    alpha[s] = R(s, a) + γ ∑_o ∑_sp P(sp | s, a) P(o | sp) α_o[sp]

Returns `(dot(alpha, b), alpha, SparseCat((o, b_o) pairs, P(o | b, a)))`, or a
zero vector for terminal beliefs. NOTE(review): an observation whose
probability under `b` is 0 produces a NaN next belief.
"""
function backup(env::POMDP, b::DiscreteHashedBelief, a, Alphas::Vector{<:AlphaVector})
    # Get relevant variables:
    Ss, Sps, Os = get_relevant_sets(env, b, a)
    Alphas = Alphas[map(alpha -> support_has_overlap(alpha, Sps), Alphas)]

    if isterminalbelief(env, b)
        return 0.0, AlphaVector(zeros(length(Ss)), Ss, a), SparseCat([(observations(env)[1], b)], [1.0])
    end

    Prob_o_given_sp = [pdf(observation(env,a,sp), o) for (o,sp) in Iterators.product(Os, Sps)]
    Prob_sp_given_s = [pdf(transition(env,s,a), sp) for (sp,s) in Iterators.product(Sps, Ss)]
    γ = discount(env)

    Bs, B_probs = [], Float64[]
    # Immediate rewards as Float64: the discounted future terms are added in
    # place below, which would throw InexactError on an Int vector.
    alphastar = Float64[reward(env, s, a) for s in Ss]
    for (oidx, o) in enumerate(Os)
        bo_vector = map(spidx -> sum(sidx -> pdf(b, Ss[sidx]) * Prob_sp_given_s[spidx,sidx] * Prob_o_given_sp[oidx,spidx], 1:length(Ss)), 1:length(Sps))
        prob = sum(bo_vector)               # P(o | b, a)
        bo = DiscreteHashedBelief(Sps, bo_vector ./ prob)
        push!(Bs, (o, bo))
        push!(B_probs, prob)
        alpha_o = argmax(alpha -> dot(alpha, bo), Alphas)
        alpha_o_values = alpha_o[Sps]       # independent of s: look up once
        for sidx in eachindex(Ss)
            alphastar[sidx] += γ * sum(Prob_sp_given_s[:, sidx] .* Prob_o_given_sp[oidx, :] .* alpha_o_values)
        end
    end

    alpha = AlphaVector(alphastar, Ss, a)
    return dot(alpha, b), alpha, SparseCat(Bs, B_probs)
end


##################################################################
#                       Other
##################################################################

"""
    ZeroAlphas()

Solver producing a trivial lower-bound policy made of one constant alpha-vector.
"""
struct ZeroAlphas <: Solver end

"""
    POMDPs.solve(::ZeroAlphas, env)

Pick the action `a*` maximising `min_s R(s, a)` (that maximum is `Rmin`) and
return a policy with one alpha-vector for `a*` equal to `Rmin / (1 - γ)` on
every state. Always playing `a*` collects at least `Rmin` per step.

If `env` has terminal states, the constant is capped at 0. After termination
no more reward is collected, so a positive `Rmin / (1 - γ)` would over-estimate
trajectories that terminate, while `min(Rmin / (1 - γ), 0)` stays a lower bound.

TODO(review): this solver was previously flagged as incorrect without details;
the terminal-state cap addresses one known issue. Confirm nothing else was meant.
"""
function POMDPs.solve(solver::ZeroAlphas, env)
    Rmin, aminidx = findmax(a -> minimum(s -> reward(env, s, a), states(env)), actions(env))
    Vmin = Rmin / (1.0 - discount(env))
    if any(s -> isterminal(env, s), states(env))
        Vmin = min(Vmin, 0.0)
    end
    alpha_normal = AlphaVector(zeros(length(states(env))) .+ Vmin, collect(states(env)), actions(env)[aminidx])
    return RobustAlphaVectorPolicy(env, [alpha_normal])
end
