module groups_rl

using Random, Distributions
using LinearAlgebra, SparseArrays, Statistics, StatsBase
using MLDataUtils, Clustering
using JuMP

using Distributed, SharedArrays
using parallel
using Trees, bound_rl, Nodes_rl

export n_groups

"""
    unique_inverse(A::AbstractArray)

Return `(vals, locs)` where `vals` lists the distinct elements of `A` in order of first
appearance and `locs[k]` holds the 1-based iteration positions at which `vals[k]` occurs
in `A`. Equality is decided by a `Dict` lookup (i.e. `isequal`/`hash`).
"""
function unique_inverse(A::AbstractArray)
    T = eltype(A)
    vals = T[]
    locs = Vector{Int}[]
    slot = Dict{T, Int}()   # element => its index in `vals`
    pos = 0                 # 1-based position in iteration order
    for x in A
        pos += 1
        if haskey(slot, x)
            push!(locs[slot[x]], pos)
        else
            # first occurrence: open a new slot for this value
            push!(vals, x)
            push!(locs, [pos])
            slot[x] = length(vals)
        end
    end
    return vals, locs
end

# function grouping(X, tree, D, ngroups)
#     nz,Tl = size(tree.z)
#     p,n = size(X)
#     if nz != n
#         z = bound.warm_start_z(X, tree.a, tree.b, D)
#     else
#         z = tree.z
#     end
#     return kmeans_group(X, tree.c, z, D, ngroups)
# end


# determine the number of groups
"""
    n_groups(states, V_update, feature_dim, num_states, num_actions, D, warm_start, method)

Decide how many groups the `num_states` samples are split into and build those groups.

- If `"SG"` is in `method`: with fewer than 50 samples, 2 groups are used. Otherwise a
  target group size `gp_size` is derived from `D`, `num_states`, `num_actions` and
  `feature_dim`, and `ngroups = round(num_states / gp_size)`, clamped to
  `1:num_states`. Samples are assigned round-robin: sample `i` goes to group
  `i % ngroups + 1`.
- Otherwise (e.g. LD or closed-form methods) every sample is its own group.

`states`, `V_update` and `warm_start` are currently unused.

Returns `(ngroups, groups)` where `groups::Vector{Vector{Int}}` holds sample indices.
"""
function n_groups(states, V_update, feature_dim, num_states, num_actions, D, warm_start::Tree, method)
    if "SG" in method
        if num_states < 50
            ngroups = 2
        else
            if D > 3
                gp_size = (1000 <= num_states <= 5000) ? 100 : 30
            else
                # Empirical constant: 319 for small/medium data, 150 for large data.
                # (Original note: for D = 2, samples per group <= 100 — NOTE(review): unverified.)
                if num_states <= 1500 && num_actions <= 10 && feature_dim <= 10
                    cnst = 319
                else
                    cnst = 150
                end
                gp_size = 1 / (2^D + num_actions) *
                          (cnst + feature_dim + 2 - 2^D * (feature_dim + 2 + num_actions))
                if gp_size <= 0
                    gp_size = max(30, num_actions)
                end
            end
            # Clamp: a gp_size larger than 2*num_states would round to 0 groups
            # (DivideError in `i % ngroups` below), and a gp_size below 1 would create
            # more groups than samples (groups that are guaranteed to be empty).
            ngroups = clamp(round(Int, num_states / gp_size), 1, num_states)
        end
        parallel.root_println("ngroups: $ngroups")
        # Split the samples round-robin in their original order.
        groups = [Int[] for _ in 1:ngroups]
        for i in 1:num_states
            push!(groups[i % ngroups + 1], i)
        end
    else
        # LD / closed-form: one group per sample.
        ngroups = num_states
        groups = [[i] for i in 1:ngroups]
    end
    return ngroups, groups
end

# function group_distribute(X_gp::Union{Nothing, Vector{Matrix{Float64}}}, Y_gp::Union{Nothing, Vector{Matrix{Float64}}}, ngroups_all::Int64)
"""
    group_distribute(states_gp, Q_value_gp, ngroups_all)

Partition `ngroups_all` groups across processes and hand each process its share.

Calls `parallel.partition_concat(ngroups_all)`, reads this process's group indices with
`parallel.getpartition()`, then passes `states_gp` and `Q_value_gp` through
`parallel.spread` (per the original comment: spread to each process according to its
partition list).

Returns `(states_gp, Q_value_gp, states_proc, Q_value_proc, gp_list)`, where the first
two are this process's groups. `states_proc` / `Q_value_proc` are those groups stacked
row-wise (wrapped in a `view`). They are empty `Matrix{Float64}[]` vectors when this
process received no groups.
"""
function group_distribute(states_gp, Q_value_gp, ngroups_all)
    parallel.partition_concat(ngroups_all)
    gp_list = parallel.getpartition()

    local_states = parallel.spread(states_gp)
    local_Q = parallel.spread(Q_value_gp)

    if length(local_states) == 0
        # This process holds no groups: return empty containers for the stacked data.
        return local_states, local_Q, Matrix{Float64}[], Matrix{Float64}[], gp_list
    end

    stacked_states = view(reduce(vcat, local_states), :, :)
    stacked_Q = view(reduce(vcat, local_Q), :, :)
    return local_states, local_Q, stacked_states, stacked_Q, gp_list
end


# Divide states and Q_value into groups: the root builds the groups and all processes receive them via broadcast
"""
    group_generation(states, Q_value, feature_dim, num_states, num_actions, D, warm_start, method)

On the root process, group the samples with `n_groups` (passing `Q_value` as its
`V_update` argument). For each group, slice the corresponding rows of `states` and
`Q_value` into a `Matrix{Float64}`. Every value is then passed through
`parallel.bcast`. Non-root processes start from `nothing`; NOTE(review): presumably
`bcast` fills in the root's values there.

Returns `(states_gp, Q_value_gp, ngroups_all, groups_all)`.
"""
function group_generation(states, Q_value, feature_dim, num_states, num_actions, D, warm_start, method)
    ngroups_all = groups_all = states_gp = Q_value_gp = nothing
    if parallel.is_root()
        ngroups_all, groups_all = n_groups(states, Q_value, feature_dim, num_states,
                                           num_actions, D, warm_start, method)
        states_gp = Matrix{Float64}[]
        Q_value_gp = Matrix{Float64}[]
        println("start!")
        for rows in groups_all
            push!(states_gp, states[rows, :])
            push!(Q_value_gp, Q_value[rows, :])
        end
    end
    # NOTE(review): keep this broadcast order identical on every process.
    ngroups_all = parallel.bcast(ngroups_all)
    groups_all = parallel.bcast(groups_all)
    states_gp = parallel.bcast(states_gp)
    Q_value_gp = parallel.bcast(Q_value_gp)
    return states_gp, Q_value_gp, ngroups_all, groups_all
end

"""
    groups_on_proc(groups_all, gp_list)

Map the groups selected by `gp_list` onto contiguous row ranges of their row-wise
concatenation. The `k`-th returned range covers the rows of group `gp_list[k]`,
assuming the groups are stacked in `gp_list` order. Empty groups yield empty ranges.
"""
function groups_on_proc(groups_all::Vector{Vector{Int64}}, gp_list::Vector{Int64})
    ranges = UnitRange{Int64}[]
    offset = 0   # rows already consumed by earlier groups
    for g in gp_list
        len = length(groups_all[g])
        push!(ranges, (offset + 1):(offset + len))
        offset += len
    end
    return ranges
end

"""
    group_trees_init(warm_start::Tree, ngroups::Int64)

Return a `Vector{Tree}` of `ngroups` freshly constructed trees. Each is built from the
`a`, `b`, `c`, `d` and `D` fields of `warm_start`, with `nothing` passed as the fifth
constructor argument. The field values are passed as-is (no copy is made here).
"""
function group_trees_init(warm_start::Tree, ngroups::Int64)
    return Tree[
        Tree(warm_start.a, warm_start.b, warm_start.c, warm_start.d, nothing, warm_start.D)
        for _ in 1:ngroups
    ]
end

"""
    group_generation_rand(states, Q_value, n_all, iter)

Draw random bootstrap groups of samples on the root process.

On the root, `ng_all_rand` groups are drawn: 6 when `parallel.nprocs() <= 3`, otherwise
one per process. For group `k`, the global RNG is reseeded with `k * (iter + 1)`. Then
50 indices are sampled from `1:n_all` with replacement, duplicates are dropped, and the
matching rows of `states` and `Q_value` are sliced.

Only `ng_all_rand` and `gp_all_rand` go through `parallel.bcast`. On non-root processes
`states_rand` and `Q_value_rand` are returned as `nothing`.

Returns `(states_rand, Q_value_rand, ng_all_rand, gp_all_rand)`.
"""
function group_generation_rand(states, Q_value, n_all, iter)
    states_rand = Q_value_rand = ng_all_rand = gp_all_rand = nothing
    if parallel.is_root()
        states_rand = Matrix{Float64}[]
        Q_value_rand = Matrix{Float64}[]
        ng_all_rand = parallel.nprocs() <= 3 ? 6 : parallel.nprocs()
        gp_all_rand = Vector{Int64}[]
        for k in 1:ng_all_rand::Int
            # NOTE: reseeds the *global* RNG with k * (iter + 1)
            Random.seed!(k * (iter + 1))
            # bootstrap draw (with replacement), duplicates removed
            rows = unique(sample(1:n_all, 50, replace = true))
            push!(gp_all_rand, rows)
            push!(states_rand, states[rows, :])
            push!(Q_value_rand, Q_value[rows, :])
        end
    end
    ng_all_rand = parallel.bcast(ng_all_rand)
    gp_all_rand = parallel.bcast(gp_all_rand)
    return states_rand, Q_value_rand, ng_all_rand, gp_all_rand
end


"""
    proc_data_preparation(states, Q_value, feature_dim, num_states, num_actions, D,
                          warm_start, method, rand=false, iter=0, val=0)

Build this process's share of the grouped data and the root branch-and-bound `Node`.

NOTE(review): the `parallel.*` calls made here and inside the helpers (`bcast`,
`partition_concat`, `spread`, `get_max_list_length`) look like collective operations.
Keep their order identical on every process; confirm in the `parallel` module.

# Arguments
- `states`, `Q_value`: full data. They are read only inside `parallel.is_root()` branches
  (here and in the grouping helpers).
- `feature_dim`, `num_states`, `num_actions`, `D`, `method`: forwarded to the grouping
  and bound-initialization helpers.
- `warm_start`: tree whose `a`, `b`, `c`, `d` fields seed `LB_tree` and the per-group trees.
- `rand`: when `true`, use random groups from `group_generation_rand` and skip `LB_tree`.
  This argument shadows `Base.rand` inside the function.
- `iter`: forwarded to `group_generation_rand` (it is part of the RNG seed there).
- `val`: forwarded to `bound_rl.init_bound`.

# Returns
`(states_gp, Q_value_gp, states_proc, Q_value_proc, node, LB_tree)`.
- `LB_tree` is `nothing` when `rand` is `true`.
- `Q_value_gp` is padded with dummy entries up to `parallel.get_max_list_length()`.
"""
function proc_data_preparation(states, Q_value, feature_dim, num_states, num_actions, D, warm_start, method, rand::Bool=false, iter::Int64=0, val::Int64=0)
    # Largest Q value over the full dataset: computed on the root, then broadcast.
    # It fills the dummy Q_value blocks added below.
    if parallel.is_root()
        max_Q = maximum(Q_value)
    else
        max_Q = nothing
    end
    max_Q = parallel.bcast(max_Q)
    
    # Build the groups: deterministic grouping, or random bootstrap groups when `rand`.
    if !rand
        states_gp, Q_value_gp, ngroups_all, groups_all = group_generation(states, Q_value, feature_dim, num_states, num_actions, D, warm_start, method)
    else
        states_gp, Q_value_gp, ngroups_all, groups_all = group_generation_rand(states, Q_value, num_states, iter)
    end
    # Hand each process its groups.
    # states_proc / Q_value_proc are those groups stacked row-wise.
    states_gp, Q_value_gp, states_proc, Q_value_proc, gp_list = group_distribute(states_gp, Q_value_gp, ngroups_all)
    # The initial best tree on each process only has z for the data on that process.
    if !rand
        z_proc = bound_rl.warm_start_z(states_proc, warm_start.a, warm_start.b, D)
        LB_tree = Tree(warm_start.a, warm_start.b, warm_start.c, warm_start.d, z_proc, D)
    else # random groups are only used for UB selection, so no best tree is built
        LB_tree = nothing
    end
    # Number of samples (rows) held by this process; 0 if it received no groups.
    if length(states_proc) == 0
        n = 0
    else
        n = size(states_proc)[1]
    end
    ngroups = length(gp_list)
    # Row ranges of each local group inside states_proc / Q_value_proc.
    groups = groups_on_proc(groups_all, gp_list)
    # One tree per local group, built from warm_start's fields.
    group_trees = group_trees_init(warm_start, ngroups)
    # Pad with dummy groups so that every process holds max_l_len groups.
    # These dummies are used in the dual bound computation.
    max_l_len = parallel.get_max_list_length()
    if ngroups < max_l_len
        for i in 1:max_l_len-ngroups
            push!(groups, 0:-1) # dummy group with length(group) = 0
            push!(group_trees, Tree()) # dummy tree with tree.D = 0
            # Dummy Q block of size (num_states, num_actions) filled with max_Q,
            # so its UB_gp entry below equals num_states * max_Q.
            # NOTE(review): this allocates a full num_states x num_actions matrix per dummy group.
            push!(Q_value_gp, ones(num_states, num_actions)*max_Q)
        end
    end
    # UB_gp[i]: sum over rows of the row-wise maximum of Q_value_gp[i] (dummies included).
    UB_gp = [sum(vec(maximum(Q_value_gp[i], dims=2))) for i in 1:max_l_len]
    lrg_gap = falses(max_l_len)
    # lower and upper are Tree-typed bounds; per the original note, z here only covers
    # the n samples held by this process.
    lower, upper = bound_rl.init_bound(feature_dim, n, num_actions, D, nothing, nothing, val)
    # Root branching node. Node level starts from 1 for debugging
    # (original note: change back to 0 after debugging).
    # -ones(n)*typemax(Float64) is a length-n vector of -Inf.
    node = Node(lower, upper, 1, 1e15, -ones(n)*typemax(Float64), groups, nothing, group_trees, UB_gp, lrg_gap, "d"); 
    return states_gp, Q_value_gp, states_proc, Q_value_proc, node, LB_tree
end

end