# Partially observable variant of the Machine Replacement MDP described by Delage and Mannor (2010): https://www.jstor.org/stable/40605970?seq=8
# To add partial observability, we assume the agent only observes its next state if it pays some measurement cost.

# All model constants are declared `const`, so functions reading them stay type-stable
# (non-const globals are `Any`-typed from the compiler's point of view).

# Integer codes of the two repair states; regular machine states are 1:nmbr_states.
const BIGREPAIRSTATE = -1
const SMALLREPAIRSTATE = 0

# First component of an action tuple: leave the machine alone or repair it.
const NOOP = 1
const REPAIR = 2

# Second component of an action tuple: whether to pay for observing the next state.
const NOMEASURE = 1
const MEASURE = 2

# Probabilities paired, in order, with the outcomes [s, sp] used by `transition` under NOOP.
const PROBS_NOOP = [0.2, 0.8]
# Probabilities paired, in order, with the outcomes
# [sp, SMALLREPAIRSTATE, BIGREPAIRSTATE] used by `transition` under REPAIR.
const PROBS_REP = [0.3, 0.6, 0.1]

"""
    MachineReplacement(; discount, nmbr_states, big_repair_cost, small_repair_cost,
                         break_cost, measure_cost)

Parameters of the partially observable machine replacement problem.

All cost fields are added directly to the reward, so the defaults are negative numbers.
Every cost is stored as `Float64`, which allows fractional costs for all of them.
"""
@kwdef struct MachineReplacement <: IPOMDP{Int, Int, Int}
    # NOTE(review): actions in this file are `Tuple{Int, Int}` (see `actiontype`);
    # confirm that the `Int` type parameters of `IPOMDP` are what is intended.

    # Discount factor returned by `discount`.
    discount::Float64           = 0.90
    # Number of regular machine states; the state space is these plus the two repair states.
    nmbr_states::Int            = 5
    # Added to the reward while in BIGREPAIRSTATE.
    big_repair_cost::Float64    = -10.0
    # Added to the reward while in SMALLREPAIRSTATE.
    small_repair_cost::Float64  = -2.0
    # Added to the reward while in the last regular state (`nmbr_states`).
    break_cost::Float64         = -20.0
    # Added to the reward whenever the action includes MEASURE.
    measure_cost::Float64       = -1.0
end

"""
    POMDPs.states(M::MachineReplacement)

Return the state space as a vector: `BIGREPAIRSTATE`, `SMALLREPAIRSTATE`, followed by
the regular states `1:M.nmbr_states`.
"""
function POMDPs.states(M::MachineReplacement)
    repair_states = [BIGREPAIRSTATE, SMALLREPAIRSTATE]
    return vcat(repair_states, 1:M.nmbr_states)
end

"""
    POMDPs.statetype(M::MachineReplacement)

States are plain integers.
"""
function POMDPs.statetype(M::MachineReplacement)
    return Int
end

"""
    POMDPs.stateindex(M::MachineReplacement, s)

Return the position of `s` within `states(M)`, or `nothing` if `s` is not a state.
"""
function POMDPs.stateindex(M::MachineReplacement, s)
    all_states = states(M)
    return findfirst(isequal(s), all_states)
end
"""
    POMDPs.actions(M::MachineReplacement)

Return every `(act, measure)` pair as a vector. The machine action varies fastest,
i.e. the order is (NOOP, NOMEASURE), (REPAIR, NOMEASURE), (NOOP, MEASURE), (REPAIR, MEASURE).
"""
function POMDPs.actions(M::MachineReplacement)
    combos = Iterators.product((NOOP, REPAIR), (NOMEASURE, MEASURE))
    return vec(collect(combos))
end

"""
    POMDPs.actiontype(M::MachineReplacement)

Actions are `(act, measure)` tuples of two integers.
"""
function POMDPs.actiontype(M::MachineReplacement)
    return Tuple{Int, Int}
end

"""
    POMDPs.actionindex(M::MachineReplacement, a)

Return the position of `a` within `actions(M)`, or `nothing` if `a` is not an action.
"""
function POMDPs.actionindex(M::MachineReplacement, a)
    all_actions = actions(M)
    return findfirst(isequal(a), all_actions)
end
"""
    POMDPs.observations(M::MachineReplacement)

The observation space coincides with the state space.
"""
POMDPs.observations(M::MachineReplacement) = states(M)

"""
    POMDPs.obstype(M::MachineReplacement)

Observations are plain integers.
"""
POMDPs.obstype(M::MachineReplacement) = Int

"""
    POMDPs.obsindex(M::MachineReplacement, o)

Return the position of `o` within `observations(M)`, or `nothing` if `o` is not an
observation.

The observation values start at `BIGREPAIRSTATE` (-1), so the value of `o` itself is not
a valid 1-based index; the index is looked up the same way as in `stateindex` and
`actionindex`.
"""
POMDPs.obsindex(M::MachineReplacement, o) = findfirst(isequal(o), observations(M))
"""
    POMDPs.discount(M::MachineReplacement)

Return the discount factor stored in the model.
"""
function POMDPs.discount(M::MachineReplacement)
    return M.discount
end

"""
    POMDPs.initialstate(M::MachineReplacement)

Return the initial state distribution: all probability mass on state 1.
"""
function POMDPs.initialstate(M::MachineReplacement)
    start_state = 1
    return SparseCat([start_state], [1.0])
end

"""
    POMDPs.isterminal(M::MachineReplacement, s)

No state is terminal; always returns `false`.
"""
function POMDPs.isterminal(M::MachineReplacement, s)
    return false
end

"""
    POMDPs.transition(M::MachineReplacement, s, a)

Return the distribution over next states for state `s` and action `a = (act, measure)`.

- `NOOP` in `BIGREPAIRSTATE`: stay in `BIGREPAIRSTATE` with probability 1.
- `NOOP` elsewhere: outcomes `[s, sp]` weighted by `PROBS_NOOP`.
- `REPAIR`: outcomes `[sp, SMALLREPAIRSTATE, BIGREPAIRSTATE]` weighted by `PROBS_REP`.

Here `sp = min(s + 1, M.nmbr_states)`. The measurement component of `a` is ignored.
If `act` is neither `NOOP` nor `REPAIR`, a warning is logged and `nothing` is returned.
"""
function POMDPs.transition(M::MachineReplacement, s, a)
    act, _ = a  # the measurement choice does not influence the dynamics
    # Successor state, capped at the last regular state.
    sp = min(s + 1, M.nmbr_states)

    if act == NOOP
        s == BIGREPAIRSTATE && return SafeSparseCat([BIGREPAIRSTATE], [1.0])
        # At the last state sp == s, so the outcome list contains a duplicate.
        # NOTE(review): presumably SafeSparseCat merges duplicate outcomes; verify.
        return SafeSparseCat([s, sp], PROBS_NOOP)
    elseif act == REPAIR
        # From BIGREPAIRSTATE, sp equals SMALLREPAIRSTATE, again giving a duplicate outcome.
        return SafeSparseCat([sp, SMALLREPAIRSTATE, BIGREPAIRSTATE], PROBS_REP)
    end

    # Unknown machine action: report through the logging system and make the
    # "no distribution" result explicit instead of relying on a print call's return value.
    @warn "no valid transition found for s = $s, a = $a"
    return nothing
end

"""
    POMDPs.observation(M::MachineReplacement, a, sp)

Return the observation distribution after taking `a = (act, measure)` and landing in `sp`.

With `MEASURE` the next state `sp` is observed with certainty; otherwise the fixed
observation `1` is emitted regardless of `sp`. The machine action is ignored.
"""
function POMDPs.observation(M::MachineReplacement, a, sp)
    _, measure = a
    observed = measure == MEASURE ? sp : 1
    return SparseCat([observed], [1.0])
end

"""
    POMDPs.reward(M::MachineReplacement, s, a)

Return the reward for being in state `s` and taking `a = (act, measure)`.

Starting from zero, the following terms are accumulated when they apply: the measurement
cost (if `measure == MEASURE`), the big or small repair cost (if `s` is the corresponding
repair state), and the break cost (if `s` is the last regular state). The machine action
itself does not affect the reward.
"""
function POMDPs.reward(M::MachineReplacement, s, a)
    _, measure = a
    total = 0.0

    if measure == MEASURE
        total += M.measure_cost
    end
    if s == BIGREPAIRSTATE
        total += M.big_repair_cost
    end
    if s == SMALLREPAIRSTATE
        total += M.small_repair_cost
    end
    if s == M.nmbr_states
        total += M.break_cost
    end

    return total
end