using Random: GLOBAL_RNG

# Abstract supertype of the Markov decision process models in this file.
# `mdp_step` and `simulate_mdp` dispatch on it.
abstract type MDP end

# MDP model carrying a single parameter `aC`.
# NOTE(review): `aC` is presumably the probability of a customer being treated, as
# in `TSRBirthDeathMDP` — confirm. No sampling methods for this type are visible here.
struct ABMDP <: MDP
    aC :: Float64
end

# One record of a simulated trajectory: time index, state, action and reward.
# `Ts` is the state type and `Ta` the action type.
struct SAR{Ts, Ta}
    t::Int64    # step counter; `mdp_step` increments it by one per transition
    s::Ts       # state at step `t`
    a::Ta       # action sampled for state `s`
    r::Float64  # reward; `mdp_step` stores the reward of the transition that led to `s`
end

# A (state, action, reward, next state, next action) transition record.
# `Ts` is the state type and `Ta` the action type.
struct SARSA{Ts, Ta}
    t::Int64    # time index of the starting record
    s::Ts       # starting state
    a::Ta       # action taken in `s`
    r::Float64  # reward attached to the transition
    snew::Ts    # state reached after the transition
    anew::Ta    # action sampled in `snew`
end

"""
    SARSA(sarsar::Tuple{SAR{Ts, Ta}, SAR{Ts, Ta}})

Merge two consecutive `SAR` records into a single `SARSA` transition.

The time index, state and action are taken from the first record; the reward, next
state and next action are taken from the second record.
"""
function SARSA(sarsar::Tuple{SAR{Ts, Ta}, SAR{Ts, Ta}})::SARSA{Ts, Ta} where {Ts, Ta}
    first_rec, second_rec = sarsar
    return SARSA(
        first_rec.t, first_rec.s, first_rec.a,
        second_rec.r, second_rec.s, second_rec.a,
    )
end

"""
    mdp_step(rng::AbstractRNG, mdp::MDP, sar::SAR{Ts, Ta})

Advance the trajectory by one transition and return the next `SAR` record.

The next state is sampled from the current state and action, then the next action
is sampled in that new state. The returned record carries `sar.t + 1`, the new
state, the new action, and the reward of the transition just taken.
"""
function mdp_step(rng::AbstractRNG, mdp::MDP, sar::SAR{Ts, Ta})::SAR{Ts, Ta} where {Ts, Ta}
    # The state is drawn before the action so the RNG is consumed in a fixed order.
    next_state = sample_state(rng, mdp, sar.s, sar.a)
    next_action = sample_action(rng, mdp, next_state)
    reward = sample_reward(mdp, sar.s, sar.a, next_state)
    return SAR(sar.t + 1, next_state, next_action, reward)
end

"""
    log_time_checker(; r=1.1, n=1000)

Return a predicate that tells whether a record's time index `t` is one of the
geometrically spaced checkpoints `round(r^k)` for `k = 1, …, n`.

# Keywords
- `r`: base of the geometric sequence.
- `n`: number of powers of `r` to generate.

# Returns
A one-argument function; it reads the field `t` of its argument and returns `true`
when `t` is one of the rounded powers.

The checkpoints are stored in a `Set`, so each call of the predicate is a hash
lookup instead of a scan over up to `n` values.
"""
function log_time_checker(; r=1.1, n=1000)
    # Rounding collapses many small powers onto the same value; the Set deduplicates.
    checkpoints = Set(round.(r .^ (1:n)))
    return s -> s.t in checkpoints
end

"""
    simulate_mdp(rng::AbstractRNG, mdp::MDP, s0)

Build a pipeline of `SARSA` transitions for `mdp` started from state `s0`.

The trajectory is seeded with the record `SAR(0, s0, 1, 0.)` and advanced with
`mdp_step`. The seed record is dropped, and each consecutive pair of the remaining
records is converted to a `SARSA`.

The stages `Iterated`, `Drop`, `Consecutive` and `Map` are defined outside this file
(presumably Transducers.jl — confirm); the result is whatever that pipeline yields.
"""
function simulate_mdp(rng::AbstractRNG, mdp::MDP, s0::Ts) where Ts
    seed = SAR(0, s0, 1, 0.)
    advance(sar) = mdp_step(rng, mdp, sar)
    trajectory = Iterated(advance, seed)
    return trajectory |> Drop(1) |> Consecutive(2; step=1) |> Map(SARSA)
end
"""
    simulate_mdp(mdp::MDP, s0::Int64)

Convenience method: simulate `mdp` from state `s0` using `GLOBAL_RNG`.
"""
function simulate_mdp(mdp::MDP, s0::Int64)
    return simulate_mdp(GLOBAL_RNG, mdp, s0)
end

# Birth-death MDP whose state is a pair of counts, one per group (control, treated).
# The treated group has capacity `ntr(mdp)` and the control group `N - ntr(mdp)`.
# NOTE(review): the meaning of the "TSR" prefix is not visible in this file.
struct TSRBirthDeathMDP <: MDP
    N::Int64 # Total capacity; presumably the number of listings — confirm
    aL::Float64 # Prob of listing being treated
    aC::Float64 # Prob of customer being treated
    μ::Float64 # Scales the replenishment and no-change weights in `scenario_tree`
    λ::Float64 # Scales the customer-arrival weight in `scenario_tree`
    vs::Vector{Float64} # Read as `vs[1]` and `vs[a]` (action `a`) in `scenario_tree`
end

"""
    ntr(mdp::TSRBirthDeathMDP)

Return the capacity of the treated group, `floor(mdp.N * mdp.aL)`, as an `Int`.
"""
function ntr(mdp::TSRBirthDeathMDP)
    return Int(floor(mdp.N * mdp.aL))
end
"""
    num_states(mdp::TSRBirthDeathMDP)

Return the number of distinct `(control, treated)` count pairs: each count ranges
from zero up to its group's capacity.
"""
function num_states(mdp::TSRBirthDeathMDP)
    treated_cap = ntr(mdp)
    control_cap = mdp.N - treated_cap
    return (treated_cap + 1) * (control_cap + 1)
end
"""
    is_valid_state(mdp::TSRBirthDeathMDP, state::Tuple{Int64, Int64})

Return `true` when both counts in `state` are non-negative, the first does not
exceed `mdp.N - ntr(mdp)`, and the second does not exceed `ntr(mdp)`.
"""
function is_valid_state(mdp::TSRBirthDeathMDP, state::Tuple{Int64, Int64})
    treated_cap = ntr(mdp)
    control_cap = mdp.N - treated_cap
    n_control, n_treated = state
    return (0 <= n_control <= control_cap) && (0 <= n_treated <= treated_cap)
end
"""
    id2state(mdp::TSRBirthDeathMDP, id::Int64)

Convert a 1-based state id into its `(first, second)` count pair.

Ids enumerate the pairs with the second count varying fastest over
`0:ntr(mdp)`; this is the inverse of `state2id` on valid states. The id is not
range-checked.
"""
function id2state(mdp::TSRBirthDeathMDP, id::Int64)
    width = ntr(mdp) + 1
    row = cld(id, width) - 1
    col = mod(id - 1, width)
    return (row, col)
end
"""
    state2id(mdp::TSRBirthDeathMDP, state::Tuple{Int64, Int64})

Convert a `(first, second)` count pair into its 1-based state id.

# Throws
- `DomainError` if `state` fails `is_valid_state`.
"""
function state2id(mdp::TSRBirthDeathMDP, state::Tuple{Int64, Int64})
    is_valid_state(mdp, state) || throw(DomainError(state, "Invalid state for MDP."))
    n_control, n_treated = state
    return n_control * (ntr(mdp) + 1) + n_treated + 1
end

"""
    sample_action(rng::AbstractRNG, mdp::TSRBirthDeathMDP, s::Int64)

Sample an action: return `2` with probability `mdp.aC` and `1` otherwise.

The state `s` is accepted for interface uniformity but does not influence the draw.

A single scalar uniform is drawn from `rng`, so no temporary array is allocated.
The comparison is strict because `rand` samples from `[0, 1)`: `aC == 0` never
yields `2` and `aC == 1` always does.
"""
function sample_action(rng::AbstractRNG, mdp::TSRBirthDeathMDP, s::Int64)
    return rand(rng) < mdp.aC ? 2 : 1
end

"""
    scenario_tree(mdp::TSRBirthDeathMDP, s::Tuple{Int64, Int64}, a::Int64)

Build the weighted tree of possible successor states of `s` under action `a`.

The root has three weighted children:
- a replenishment node whose leaves add one to the first or second count,
- a customer-arrival node whose leaves remove one from the first count, remove one
  from the second count, or leave the state unchanged,
- a leaf that leaves the state unchanged.

Every leaf state is clamped to `(0, 0)`..`(N - ntr(mdp), ntr(mdp))`.

`LeafNode` and `ScenarioNode` are defined outside this file; the arguments passed
to `ScenarioNode` are presumably (value, children, weights) — confirm.
"""
function scenario_tree(mdp::TSRBirthDeathMDP, s::Tuple{Int64, Int64},
                       a::Int64)
    total = mdp.N
    rates = mdp.vs
    treated_cap = ntr(mdp)
    control_cap = total - treated_cap
    n_control, n_treated = s
    upper = (control_cap, treated_cap)

    # Leaf for the state shifted by `shift`, kept inside the valid range.
    shifted_leaf(shift) = LeafNode(clamp.(s .+ shift, (0, 0), upper))

    arrival_leaves = shifted_leaf.([(-1, 0), (0, -1), (0, 0)])
    # The first count is weighted by vs[1], the second by vs[a]; the last entry
    # weights the "no change" leaf.
    arrival_weights = [n_control * rates[1], n_treated * rates[a], total]
    customer_arrival = ScenarioNode((0, 0), arrival_leaves, arrival_weights)

    refill_leaves = shifted_leaf.([(1, 0), (0, 1)])
    refill_weights = Float64[control_cap - n_control, treated_cap - n_treated]
    replenishment = ScenarioNode((0, 0), refill_leaves, refill_weights)

    gap_total = control_cap - n_control + treated_cap - n_treated
    count_total = n_control + n_treated
    branches = [replenishment, customer_arrival, shifted_leaf((0, 0))]
    branch_weights = [gap_total * mdp.μ, total * mdp.λ, count_total * mdp.μ]
    return ScenarioNode((0, 0), branches, branch_weights)
end

"""
    sample_state(rng::AbstractRNG, mdp::TSRBirthDeathMDP, s::Tuple{Int64, Int64}, a::Int64)

Sample a successor of the count-pair state `s` under action `a` by drawing from the
tree built by `scenario_tree`.
"""
function sample_state(rng::AbstractRNG, mdp::TSRBirthDeathMDP, s::Tuple{Int64, Int64},
                      a::Int64)
    tree = scenario_tree(mdp, s, a)
    return sample_scenario(rng, tree)
end
"""
    sample_state(rng::AbstractRNG, mdp::TSRBirthDeathMDP, s::Int64, a::Int64)

Sample a successor of the state with id `s` under action `a` and return its id.

The id is decoded with `id2state`, the tuple method of `sample_state` draws the
successor, and the result is re-encoded with `state2id`.
"""
function sample_state(rng::AbstractRNG, mdp::TSRBirthDeathMDP, s::Int64, a::Int64)
    current = id2state(mdp, s)
    successor = sample_state(rng, mdp, current, a)
    return state2id(mdp, successor)
end
"""
    sample_reward(mdp::TSRBirthDeathMDP, s::Int64, a::Int64, snew::Int64)

Return the reward of moving from state id `s` to state id `snew`: the total
decrease of the two counts, with increases counted as zero. The action `a` is not
used.
"""
function sample_reward(mdp::TSRBirthDeathMDP, s::Int64, a::Int64, snew::Int64)::Float64
    before = id2state(mdp, s)
    after = id2state(mdp, snew)
    drop_first = max(before[1] - after[1], 0)
    drop_second = max(before[2] - after[2], 0)
    return drop_first + drop_second
end

# State record holding a time index and a 2x2 mutable static matrix.
# NOTE(review): `MMatrix{2, 2}` leaves the element type unspecified, so this field
# is not concretely typed; consider `MMatrix{2, 2, Float64, 4}` or a type parameter
# if every caller stores the same element type — confirm against callers.
struct TSRState
    t::Int64            # time index
    Ns::MMatrix{2, 2}   # presumably counts per (group, action) cell — confirm
end

"""
    TSRState()

Create a `TSRState` at time `0` with an all-zero 2x2 matrix.
"""
function TSRState()
    return TSRState(0, zeros(MMatrix{2, 2}))
end
