using POMDPs

using QuickPOMDPs
using POMDPTools
# import POMDPModelTools

using NativeSARSOP
using QMDP

import Base: *
using Distributions
using Statistics
using ProgressMeter


# A grid coordinate `(x, y)`; identical to `Tuple{Int,Int}`.
const Cell = NTuple{2,Int}

"""
    Grid(grid_string::String)

Rectangular character grid parsed from a multi-line string.

Every non-empty line of `grid_string` (LF or CRLF separated) is one row. Rows are stored
in reverse order, so the LAST line of the string becomes `lines[1]` (`y = 1`).
`lines[y][x]` is the character at column `x` of row `y`; `xmin:xmax` and `ymin:ymax` are
the inclusive coordinate ranges (both minimums are 1).

Throws `ArgumentError` when the string has no non-empty lines or when the rows do not
all have the same length.
"""
mutable struct Grid
    lines::Vector{Vector{Char}}
    xmin::Int64
    xmax::Int64
    ymin::Int64
    ymax::Int64

    function Grid(grid_string::String)
        # Splitting on \r?\n keeps a stray '\r' from becoming a grid cell on CRLF input.
        rows = split(grid_string, r"\r?\n", keepempty=false)
        isempty(rows) && throw(ArgumentError("grid string contains no rows"))
        lines = [collect(row) for row in reverse(rows)]
        xmin = ymin = 1
        ymax = length(lines)
        xmax = length(lines[1])
        for (y, line) in enumerate(lines)
            length(line) == xmax || throw(ArgumentError(
                "grid row $y has $(length(line)) cells, expected $xmax"))
        end
        return new(lines, xmin, xmax, ymin, ymax)
    end
end

"""
    collect_cells(grid::Grid)

Return the `(x, y)` coordinates of every cell of `grid`, ordered by increasing `y` and,
within a row, by increasing `x`.
"""
function collect_cells(grid::Grid)
    return Cell[(c, r) for (r, row) in enumerate(grid.lines) for c in eachindex(row)]
end

"""
    collect_cells(grid::Grid, type::Char)

Return the `(x, y)` coordinates of the cells of `grid` whose character equals `type`,
ordered by increasing `y` and, within a row, by increasing `x`.
"""
function collect_cells(grid::Grid, type::Char)
    rows = enumerate(grid.lines)
    return Cell[(c, r) for (r, row) in rows for (c, ch) in enumerate(row) if ch == type]
end

"""
    clamp(grid::Grid, cell::Cell)

Return `cell` with each coordinate clamped into the grid's inclusive bounds.

This is a separate `clamp` in this file's module (only `Base.*` is imported), which is
why the per-coordinate work calls `Base.clamp` explicitly.
"""
function clamp(grid::Grid, cell::Cell)
    cx = Base.clamp(cell[1], grid.xmin, grid.xmax)
    cy = Base.clamp(cell[2], grid.ymin, grid.ymax)
    return (cx, cy)
end

"""
    move(grid::Grid, agent::Cell, direction::String)

Return the cell reached by stepping from `agent` in `direction`, clamped to the grid.

Direction codes: `"0"` → y + 1, `"1"` → y - 1, `"2"` → x + 1, `"3"` → x - 1. Any other
string leaves the position unchanged (it is still passed through `clamp`).
"""
function move(grid::Grid, agent::Cell, direction::String)
    dx, dy = direction == "0" ? (0, 1) :
             direction == "1" ? (0, -1) :
             direction == "2" ? (1, 0) :
             direction == "3" ? (-1, 0) : (0, 0)
    return clamp(grid, (agent[1] + dx, agent[2] + dy))
end


# Absolute difference between the x coordinates of two cells.
grid_dx(a::Cell, b::Cell) = abs(a[1] - b[1])

# Absolute difference between the y coordinates of two cells.
grid_dy(a::Cell, b::Cell) = abs(a[2] - b[2])

"""
    grid_see(agent::Cell, cell::Cell, radius=1)

Return `true` when `cell` lies within `radius` of `agent` along both axes, i.e. inside
the `(2radius + 1) × (2radius + 1)` square centred on `agent` (Chebyshev distance).
"""
function grid_see(agent::Cell, cell::Cell, radius=1)
    grid_dx(agent, cell) > radius && return false
    return grid_dy(agent, cell) <= radius
end

"""
    cartesian_product_tuple(a, b)

Return every concatenation `(x..., y...)` with `x ∈ a` and `y ∈ b` as a flat vector.
Elements of `a` form the outer loop, so all combinations for `a[1]` come first.
"""
function cartesian_product_tuple(a::Vector{<:Tuple}, b::Vector{<:Tuple})
    # Column-major layout puts `y` on the fastest axis, which matches the nested-loop order.
    table = [(x..., y...) for y in b, x in a]
    return vec(table)
end


# Product of two independent `SparseCat` distributions.
#
# Each joint outcome is the concatenation `(x1..., x2...)`. A scalar outcome splats as a
# 1-tuple, so `SparseCat([1, 2], p) * SparseCat([3], q)` has outcomes `(1, 3)` and `(2, 3)`.
# Outcomes are ordered with `d1` as the outer loop. The weight of each joint entry is the
# product of the two stored weights.
#
# NOTE(review): this adds a `Base.*` method for a type owned by POMDPTools (type piracy).
# Kept for compatibility with existing callers.
function *(d1::POMDPTools.SparseCat{<:Vector{<:Any},<:Any},
           d2::POMDPTools.SparseCat{<:Vector{<:Any},<:Any})
    # Multiply the stored (value, weight) entries instead of calling `pdf` on each support
    # value. `pdf` looks a value up by equality (in POMDPTools it returns the first match;
    # NOTE(review): confirm for the installed version). When a support lists the same value
    # twice, e.g. `SparseBernoulli(a, a, p)`, every copy would get the first copy's weight
    # and the product would no longer sum to one.
    joint_support = Any[(x1..., x2...) for x1 in d1.vals for x2 in d2.vals]
    joint_probs = [p1 * p2 for p1 in d1.probs for p2 in d2.probs]
    return POMDPTools.SparseCat(joint_support, joint_probs)
end

# A `SparseCat` that puts all of its (integer) weight `1` on the single outcome `x`.
SparseDeterministic(x::Any) = POMDPTools.SparseCat([x], [1])

# A two-entry `SparseCat`: outcome `x` with weight `prob`, outcome `y` with `1 - prob`.
# If `x == y`, the support lists that value twice.
SparseBernoulli(x::Any, y::Any, prob::Float64) = POMDPTools.SparseCat([x, y], [prob, 1 - prob])


# Merge repeated outcomes of a `SparseCat` (summing their weights) and rescale the weights
# so they sum to one. Works from the stored `vals`/`probs` pairs rather than `pdf`, so an
# outcome listed several times is counted with its full combined weight.
function _merge_and_normalize(distr::POMDPTools.SparseCat)
    outcomes = eltype(distr.vals)[]
    mass = Dict{eltype(distr.vals),Float64}()
    for (outcome, weight) in zip(distr.vals, distr.probs)
        if !haskey(mass, outcome)
            push!(outcomes, outcome)
            mass[outcome] = 0.0
        end
        mass[outcome] += weight
    end
    total = sum(mass[o] for o in outcomes)
    return POMDPTools.SparseCat(outcomes, [mass[o] / total for o in outcomes])
end

"""
    to_pomdp(grid::Grid; apply_missingness::Bool = false)

Build a `QuickPOMDP` in which an agent walks on `grid` towards a goal while avoiding
traps and mud.

Grid characters:
- `'i'`: start cells. The initial state is uniform over them, with every mud cell muddy.
- `'g'`: goal cells.
- `'t'`: trap cells.
- `'m'`: mud cells.

# Model
- States are tuples `(x, y, m_1, …, m_k)`: the agent cell plus one flag per mud cell
  (`1` = muddy, `0` = dry), in the order returned by `collect_cells(grid, 'm')`.
  Positions `(0, 0)` and `(0, 1)` are the terminal goal and trap sinks.
- Actions are the direction strings `"0"` (y+1), `"1"` (y-1), `"2"` (x+1) and
  `"3"` (x-1). Moves are clamped to the grid.
- Standing on a trap/goal cell sends the agent to the trap/goal sink. Otherwise the agent
  moves as chosen with probability `1 - prob_slip` and slips left (`"3"`) with
  probability `prob_slip`. Independently, every mud flag flips with probability
  `prob_mud_dries`.
- Observations report the agent cell exactly. For each mud cell the observation is:
  - `2` if the mud cell is outside the 3×3 square around the agent (`grid_see`);
  - otherwise the true flag;
  - with `apply_missingness = true`, a visible flag is instead `-1` (blank) with
    probability `prob_mud_blank`.
- Rewards do not depend on the action: `-1000` on a trap cell, `+100` on a goal cell,
  and `-100` for each muddy mud cell the agent stands on. Discount is `0.98`.
"""
function to_pomdp(grid::Grid; apply_missingness::Bool = false)
    init_cells = collect_cells(grid, 'i')
    goal_cells = collect_cells(grid, 'g')
    trap_cells = collect_cells(grid, 't')
    mud_cells = collect_cells(grid, 'm')

    goal_sink = (0,0) # unique goal sink
    trap_sink = (0,1) # unique trap sink
    sinks = (goal_sink, trap_sink)

    prob_slip = 1/1000 # prob of slipping regardless of the action
    # Slip to the left. `move` only understands the direction strings "0"-"3" and treats
    # anything else as "stay", so the former value "←" never moved the agent.
    # "3" is the left move (x - 1).
    direction_slip = "3"
    prob_mud_dries = 1/4 # prob of mud changing state
    prob_mud_blank = 99/100 # prob of mud observation missing

    grid_states = collect_cells(grid)
    push!(grid_states, goal_sink)
    push!(grid_states, trap_sink)
    grid_observations = grid_states
    init_states = init_cells

    mud_states = [(0,), (1,)] # 0 = no mud, 1 = mud
    mud_init = [ (1,) ] # every mud cell starts muddy
    mud_observations = [mud_states..., (2,), (-1,)] # 2 = can't see mud, -1 = blank

    # Append one component per mud cell to states, initial states and observations.
    for _ in mud_cells
        grid_states = cartesian_product_tuple(grid_states, mud_states)
        init_states = cartesian_product_tuple(init_states, mud_init)
        grid_observations = cartesian_product_tuple(grid_observations, mud_observations)
    end

    directions = ["0","1","2","3"]
    actions = directions

    transition_function = Dict()
    for state in grid_states
        x, y, muds... = state
        agent = (x, y)
        for a in actions
            if agent in trap_cells
                distr = SparseDeterministic(trap_sink)
            elseif agent in goal_cells
                distr = SparseDeterministic(goal_sink)
            else
                agent_slip = move(grid, agent, direction_slip)
                agent_intended = a in directions ? move(grid, agent, a) : agent
                # Avoid listing the same cell twice, e.g. when the intended move is also
                # to the left or both moves are blocked by the border.
                distr = if agent_slip == agent_intended
                    SparseDeterministic(agent_intended)
                else
                    SparseBernoulli(agent_slip, agent_intended, prob_slip)
                end
            end
            # Each mud flag flips to 1 - m with probability prob_mud_dries.
            for index in eachindex(mud_cells)
                m = muds[index]
                distr = distr * SparseBernoulli(1-m, m, prob_mud_dries)
            end
            transition_function[(state, a)] = _merge_and_normalize(distr)
        end
    end

    observation_function = Dict()
    for sp in grid_states
        x, y, muds... = sp
        agent = (x, y)
        distr = SparseDeterministic(agent)
        for (index, mud_cell) in enumerate(mud_cells)
            m = muds[index]
            mud_distr = if !grid_see(agent, mud_cell)
                SparseDeterministic(2)
            elseif !apply_missingness
                SparseDeterministic(m)
            else
                SparseBernoulli(-1, m, prob_mud_blank)
            end
            distr = distr * mud_distr
        end
        distr = _merge_and_normalize(distr)
        # The observation depends only on the next state, so share it across actions.
        for a in actions
            observation_function[(a, sp)] = distr
        end
    end

    reward_function = Dict()
    for s in grid_states
        x, y, muds... = s
        agent = (x, y)
        reward = 0
        if agent in trap_cells
            reward += -1000
        end
        if agent in goal_cells
            reward += +100
        end
        for (index, mud_cell) in enumerate(mud_cells)
            if agent == mud_cell && muds[index] == 1
                reward += -100
            end
        end
        # The reward does not depend on the action.
        for a in actions
            reward_function[(s, a)] = reward
        end
    end

    pomdp = QuickPOMDPs.QuickPOMDP(
        states = grid_states,
        initialstate = POMDPTools.Uniform(init_states),

        actions = actions,
        discount = 0.98,

        transition = function(s, a::String)
            return transition_function[(s, a)]
        end,

        observations = grid_observations,
        observation = function(a::String, sp)
            return observation_function[(a, sp)]
        end,

        reward = function(s, a::String)
            return reward_function[(s, a)]
        end,

        # Terminal exactly when the agent position is one of the two sinks.
        isterminal = function(s)
            return (s[1], s[2]) in sinks
        end
    )

    return pomdp
end


"""
    show_policy_qvalue(pomdp, policy, name)

Print diagnostics for an alpha-vector `policy` at the initial belief of `pomdp`.

For each alpha vector, this prints the vector and then its expectation under the initial
state distribution. It then prints the initial distribution itself, and finally
`POMDPs.value` of the policy at the initial belief.
"""
function show_policy_qvalue(pomdp, policy, name)
    b0 = POMDPs.initialstate(pomdp)
    value = POMDPs.value(policy, beliefvec(pomdp, b0))
    for alpha in policy.alphas
        println(alpha)
        expected = 0
        for s in support(b0)
            expected += pdf(b0, s) * alpha[stateindex(pomdp, s)]
        end
        println(expected)
    end
    println(b0)
    println("$name policy Q-value = $value")
end


"""
    mean_and_ci(vector)

Return `(mean, halfwidth)` for the samples in `vector`, where
`halfwidth = 1.96 * std(vector) / sqrt(length(vector))`. This is the half-width of a
normal-approximation 95% confidence interval for the mean (1.96 is the two-sided 95%
standard-normal quantile).
"""
function mean_and_ci(vector)
    n = length(vector)
    μ = Statistics.mean(vector)
    halfwidth = 1.96 * Statistics.std(vector) / sqrt(n)
    return (μ, halfwidth)
end

"""
    show_policy_simulation_value(pomdp, policy, name;
                                 num_simulations = 1000, max_steps = 100)

Estimate the value of `policy` on `pomdp` by Monte-Carlo rollouts and print the mean
return together with its 95% confidence half-width (see `mean_and_ci`).

# Keywords
- `num_simulations::Integer = 1000`: number of independent rollouts (must be ≥ 1).
- `max_steps::Integer = 100`: step cap passed to `POMDPTools.RolloutSimulator`.

Throws `ArgumentError` if `num_simulations < 1`.
"""
function show_policy_simulation_value(pomdp, policy, name;
                                      num_simulations::Integer = 1000,
                                      max_steps::Integer = 100)
    num_simulations >= 1 ||
        throw(ArgumentError("num_simulations must be at least 1, got $num_simulations"))
    println("evaluating $name policy via simulations...")

    simulator = POMDPTools.RolloutSimulator(max_steps=max_steps)
    rewards = Float64[]
    sizehint!(rewards, num_simulations)
    progress = ProgressMeter.Progress(num_simulations, barlen=40)
    for _ in 1:num_simulations
        push!(rewards, POMDPs.simulate(simulator, pomdp, policy))
        ProgressMeter.next!(progress)
    end
    mean, ci = mean_and_ci(rewards)

    println("$name policy value = $mean [ci = $ci]")
end


"""
    solve_pomdp(pomdp, solver, name)

Solve `pomdp` with `solver`. Report the resulting policy first via `show_policy_qvalue`
and then via `show_policy_simulation_value`, labelling the output with `name`.
Returns `nothing`.
"""
function solve_pomdp(pomdp, solver, name)
    println("solving using $(name)...")
    policy = POMDPs.solve(solver, pomdp)
    for report in (show_policy_qvalue, show_policy_simulation_value)
        report(pomdp, policy, name)
    end
    return nothing
end

"""
    solve_grid(grid_string::String; apply_missingness::Bool = false)

Parse `grid_string` into a `Grid`, build its POMDP with `to_pomdp`, and print results for
three policies: a random policy, a QMDP policy and a SARSOP policy.
`apply_missingness` is forwarded to `to_pomdp` (default `false`, as before).
"""
function solve_grid(grid_string::String; apply_missingness::Bool = false)
    grid = Grid(grid_string)
    pomdp = to_pomdp(grid; apply_missingness = apply_missingness)
    # Use the public POMDPs API instead of QuickPOMDP's internal `data` field.
    println("created a POMDP with $(length(POMDPs.states(pomdp))) states")
    println()
    show_policy_simulation_value(pomdp, POMDPTools.RandomPolicy(pomdp), "random")
    solve_pomdp(pomdp, QMDP.QMDPSolver(), "QMDP")
    solve_pomdp(pomdp, NativeSARSOP.SARSOPSolver(), "SARSOP")
    return nothing
end



"""
    main()

Run `solve_grid` on the two built-in example grids in turn, printing a blank line and a
`----- GRID # k -----` header before each one.
"""
function main()

    grid1 = """
    .g.
    mt.
    .i.
    """

    grid2 = """
    tmg
    imt
    """

    grids = [grid1, grid2]
    for index in eachindex(grids)
        println()
        println("----- GRID # $index -----")
        solve_grid(grids[index])
    end

end

# Script entry point: runs the experiments as soon as this file is loaded (including when
# it is `include`d).
main()
