# module RPOMDP_Toy1
# using RPOMDPs
using POMDPs, Distributions, IntervalArithmetic
export Machine 
# Probability interval [0.001, 0.999], used by `transition` as the bound on
# the successors whose likelihood is left (almost) unconstrained.
# Declared `const` so that functions reading this global stay type-stable.
const mininterval = interval(0.001, 0.999)


# Interval POMDP model of a small machine with `String` states and `Int`
# actions and observations. All fields have defaults, so `Machine()` builds
# the default model and keyword arguments override individual fields.
@kwdef mutable struct Machine <: IPOMDP{String, Int, Int} 
    # Discount factor, returned by `POMDPs.discount`.
    discount::Float64               = 0.95
    # Probability interval that `transition` assigns to the "x" successor of
    # "nx" and to the "y" successor of "ny".
    # NOTE(review): the name suggests a break probability, but it is paired
    # with the "x"/"y" successor rather than "bx"/"by" — confirm intent.
    breakchance::Interval           = interval(0.9,0.9)
end

# States
# Full state space as string labels; a fresh vector is built on every call.
function POMDPs.states(::Machine)
    return ["x", "y", "nx", "ny", "bx", "by", "nb", "sink"]
end

# States are identified by their `String` label.
function POMDPs.statetype(::Machine)
    return String
end

# 1-based position of `s` within `states(M)`, or `nothing` when `s` is not
# one of the listed labels.
function POMDPs.stateindex(M::Machine, s)
    return findfirst(==(s), states(M))
end
# Two actions, encoded as the integers 1 and 2.
# NOTE(review): the previous comment named three labels ("Go, repair, flip")
# for this two-element range — confirm which labels still apply.
function POMDPs.actions(::Machine)
    return 1:2
end

# Actions are plain `Int`s.
function POMDPs.actiontype(::Machine)
    return Int
end

# Actions are already 1-based integers, so an action is its own index.
function POMDPs.actionindex(::Machine, a)
    return a
end
# Two observations, encoded as the integers 1 and 2.
# NOTE(review): the previous comment listed three labels ("x, y, z") for this
# two-element range — confirm whether a third observation is still planned.
function POMDPs.observations(::Machine)
    return 1:2
end

# Observations are plain `Int`s.
function POMDPs.obstype(::Machine)
    return Int
end

# Observations are already 1-based integers, so an observation is its own index.
function POMDPs.obsindex(::Machine, o)
    return o
end
# Discount factor stored on the model instance.
function POMDPs.discount(M::Machine)
    return M.discount
end

# Initial belief: "nx" and "ny" with probability 0.5 each.
function POMDPs.initialstate(::Machine)
    return SparseCat(["nx", "ny"], [0.5, 0.5])
end

# "sink" is the only terminal state.
function POMDPs.isterminal(::Machine, s)
    return s == "sink"
end


"""
    POMDPs.transition(M::Machine, s, a)

Return the interval-valued successor distribution (a `SparseICat`) for taking
action `a` in state `s`.

Action `2` swaps `"x"` with `"y"` and `"bx"` with `"by"`; from every other state
it moves to `"x"` or `"y"` with probability 0.5 each. Any other action value
advances the machine:

- `"x"` → `"nx"` and `"y"` → `"ny"` with certainty;
- `"nx"` → `"x"` (interval `M.breakchance`) or `"bx"` (interval `mininterval`);
- `"ny"` → `"y"` (interval `M.breakchance`) or `"by"` (interval `mininterval`);
- `"bx"` / `"by"` → `"nb"` with certainty;
- `"nb"` → `"bx"` or `"by"`, each with interval `mininterval`;
- `"sink"` → `"sink"` with certainty.

# Throws
- `ArgumentError` if `s` is not one of the states above. Previously a message
  was printed and `nothing` was returned, which is not a valid distribution.
"""
function POMDPs.transition(M::Machine, s, a)
    if a == 2
        s == "x" && return SparseICat(["y"], [interval(1.0)])
        s == "y" && return SparseICat(["x"], [interval(1.0)])
        s == "bx" && return SparseICat(["by"], [interval(1.0)])
        s == "by" && return SparseICat(["bx"], [interval(1.0)])
        # Every remaining state resets to "x" or "y". The probabilities are
        # degenerate intervals so that all branches return the same
        # interval-valued distribution type.
        # NOTE(review): this branch also applies to "sink", which `isterminal`
        # reports as terminal — confirm that leaving "sink" here is intended.
        return SparseICat(["x", "y"], [interval(0.5), interval(0.5)])
    end
    s == "sink" && return SparseICat(["sink"], [interval(1.0)])
    s == "x" && return SparseICat(["nx"], [interval(1.0)])
    s == "nx" && return SparseICat(["x", "bx"], [M.breakchance, mininterval])
    s == "y" && return SparseICat(["ny"], [interval(1.0)])
    s == "ny" && return SparseICat(["y", "by"], [M.breakchance, mininterval])
    s == "nb" && return SparseICat(["bx", "by"], [mininterval, mininterval])
    s in ("bx", "by") && return SparseICat(["nb"], [interval(1.0)])
    throw(ArgumentError("transition not recognized! (s=$s, a=$a)"))
end

"""
    POMDPs.observation(M::Machine, a, sp)

Return the deterministic observation emitted on arriving in state `sp`.

States `"y"` and `"by"` emit observation `2`; every other model state
(`"x"`, `"bx"`, `"nx"`, `"ny"`, `"nb"`, `"sink"`) emits observation `1`.
The action `a` does not influence the result.

# Throws
- `ArgumentError` if `sp` is not one of the model's states. Previously a
  message was printed and `nothing` was returned, which is not a valid
  distribution.
"""
function POMDPs.observation(M::Machine, a, sp)
    sp in ("y", "by") && return Deterministic(2)
    sp in ("x", "bx", "sink", "nx", "ny", "nb") && return Deterministic(1)
    throw(ArgumentError("observation not recognized! (sp=$sp, a=$a)"))
end

# The reward does not depend on the successor state; forward to the
# (s, a) method.
POMDPs.reward(M::Machine, s, a, sp) = POMDPs.reward(M, s, a)

"""
    POMDPs.reward(M::Machine, s, a)

Return the immediate reward for taking action `a` in state `s`.

- Action `2`: `-1.0` in `"nx"`, `"ny"` and `"nb"`; `0.0` in every other state.
- Any other action: `1.0` in `"x"` and `"y"`; `0.0` in every other state.

The final `return 0.0` makes explicit what used to happen implicitly: for
`"nx"`, `"ny"` and `"nb"` the function fell off its end and the trailing
`false` was converted to `0.0` by the `::Float64` return annotation.
"""
function POMDPs.reward(M::Machine, s, a)::Float64
    if a == 2
        return s in ("nx", "ny", "nb") ? -1.0 : 0.0
    end
    s in ("x", "y") && return 1.0
    return 0.0
end



# end