using LogicCircuits
using ProbabilisticCircuits
using StatsFuns: logsubexp
using CUDA
using LoopVectorization: @avx


import ProbabilisticCircuits: backprop_flows_down_gpu, apply_gradients_gpu

"""
    backprop_flows_down_gpu(pbc::CatParamBitCircuit, data, log_grads, values,
                            param_grads, leaf_param_grads, reuse = nothing;
                            dec_per_thread = 8, log2_threads_per_block = 7)

Launch `backprop_flows_down_layers_cuda` once per layer of `pbc.bitcircuit`, visiting
the layers in reverse order, and return `(flows, param_grads, leaf_param_grads)`.

`param_grads` and `leaf_param_grads` are zeroed before any kernel runs and are then
accumulated into by the kernels. `flows` has the type and size of `values` and is
obtained from `similar!(reuse, ...)`.

# Keywords
- `dec_per_thread`: divisor applied to the layer length before it is handed to
  `balance_threads`.
- `log2_threads_per_block`: forwarded unchanged to `balance_threads`.
"""
function backprop_flows_down_gpu(pbc::CatParamBitCircuit, data::CuMatrix, log_grads::CuVector, values::CuMatrix,
                                 param_grads::CuVector, leaf_param_grads::CuMatrix, reuse = nothing; 
                                 dec_per_thread = 8, log2_threads_per_block = 7)
    bc::CatBitCircuit = pbc.bitcircuit
    flows = similar!(reuse, typeof(values), size(values)...)

    # Start both accumulators from zero; the kernels only ever add to them.
    fill!(param_grads, zero(eltype(param_grads)))
    fill!(leaf_param_grads, zero(eltype(leaf_param_grads)))
    param_grads_device = CUDA.cudaconvert(param_grads)
    leaf_param_grads_device = CUDA.cudaconvert(leaf_param_grads)

    # The first kernel dimension runs over the rows of `values`.
    num_examples = size(values, 1)

    CUDA.@sync for layer in Iterators.reverse(bc.layers)
        num_decision_sets = length(layer)/dec_per_thread
        num_threads = balance_threads(num_examples, num_decision_sets, log2_threads_per_block)
        blocks_x = ceil(Int, num_examples/num_threads[1])
        blocks_y = ceil(Int, num_decision_sets/num_threads[2])
        num_blocks = (blocks_x, blocks_y)
        @cuda threads=num_threads blocks=num_blocks backprop_flows_down_layers_cuda(
            layer, data, bc.nodes, bc.elements, bc.parents, pbc.params, log_grads,
            flows, values, param_grads_device, leaf_param_grads_device, pbc.lit_to_var)
    end

    return flows, param_grads, leaf_param_grads
end

# CUDA kernel: compute the log-domain flow of every node listed in `layer` for every
# row of `values`, store it in `flows`, and accumulate linear-domain gradients into
# `param_grads` (one slot per element) and `leaf_param_grads`.
#
# The x dimension of the launch grid strides over rows `k` of `values`; the y dimension
# strides over the entries `i` of `layer`.
#
# Reads `flows[k, grandpa]` for the parents of each node, so those entries must already
# hold their final value when this layer runs.
# NOTE(review): presumably guaranteed by the host launching one kernel per layer in
# reverse layer order — confirm against the caller.
function backprop_flows_down_layers_cuda(layer, data, nodes, elements, parents, params, log_grads, 
        flows, values, param_grads, leaf_param_grads, lit_to_var)
    index_x = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    index_y = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    stride_x = blockDim().x * gridDim().x
    stride_y = blockDim().y * gridDim().y
    
    for k = index_x : stride_x : size(values, 1)
        for i = index_y : stride_y : length(layer)
            dec_id = @inbounds layer[i]
            # The node stored in the last column of `nodes` is treated as the root.
            if dec_id == size(nodes, 2)
                # Assign (log) gradient to the root node
                flow = log_grads[k]
            else
                # Rows 3 and 4 of `nodes` delimit this node's slice of `parents`;
                # a zero start means the node has no parents and keeps flow log(0).
                par_start = @inbounds nodes[3, dec_id]
                flow = typemin(eltype(flows)) # log(0)
                if !iszero(par_start)
                    par_end = @inbounds nodes[4, dec_id]
                    for j = par_start : par_end
                        # `par` indexes a column of `elements` (and a slot of `params` /
                        # `param_grads`); row 1 of that column is the owning node.
                        par = @inbounds parents[j]
                        grandpa = @inbounds elements[1, par]
                        sib_id = sibling(elements, par, dec_id)
                        
                        g_flow = @inbounds flows[k, grandpa]
                        d_value = @inbounds values[k, dec_id]
                        s_value = @inbounds values[k, sib_id]
                        # The element parameter θ is only added when the owning node
                        # does not satisfy `has_single_child`.
                        if has_single_child(nodes, grandpa)
                            edge_flow = g_flow + s_value
                        else
                            θ = eltype(flows)(params[par])
                            edge_flow = g_flow + s_value + θ
                        end
                        # Log-domain sum of the flows arriving over all parent edges.
                        flow = logsumexp_cuda(flow, edge_flow)
                        
                        # Compute gradient only once
                        # The same element `par` is reached from `dec_id` and from its
                        # sibling; only the side with the smaller id contributes.
                        # NOTE(review): if sib_id == dec_id nothing is accumulated —
                        # presumably cannot occur; confirm.
                        if sib_id > dec_id
                            grad::Float64 = CUDA.exp(g_flow + d_value + s_value)
                            # Atomic: threads handling different rows `k` add to the
                            # same `param_grads[par]` slot.
                            CUDA.@atomic param_grads[par] += grad
                        end
                    end
                end
            end
            @inbounds flows[k, dec_id] = flow
            # Node ids 3 .. 2 + length(lit_to_var) are handled as leaves here.
            # NOTE(review): presumably ids 1 and 2 are reserved constant nodes — confirm.
            if dec_id > 2 && dec_id <= 2 + length(lit_to_var)
                idx_d = dec_id - 2
                idx_d = @inbounds lit_to_var[idx_d]
                # `d` is used directly as a column index of `leaf_param_grads`.
                # NOTE(review): assumes `data` holds valid 1-based category indices —
                # verify against the caller.
                d = @inbounds data[k, idx_d]
                # Accumulated in the linear domain (exp of the log flow).
                leaf_grad::Float64 = CUDA.exp(flow)
                CUDA.@atomic leaf_param_grads[idx_d, d] += leaf_grad
            end
        end
    end
    return nothing
end

"""
    apply_gradients_gpu(pbc::CatParamBitCircuit, param_grads::CuVector,
                        leaf_param_grads::CuMatrix; lr::Float64 = 0.01)

Update the parameters of `pbc` in place from accumulated gradients.

For every non-empty layer of `pbc.bitcircuit`, the kernel `apply_gradients_cuda` is
launched to update `pbc.params` from `param_grads`. Afterwards `pbc.cat_params` is
replaced, in log space, by `log((1 - lr) * exp(cat_params) + lr * exp(g))`, where `g` is
`leaf_param_grads` minus its row-wise `logsumexp`, and is finally re-normalized along
`dims = 2`.

`leaf_param_grads` is used as scratch space and is overwritten.

# Keywords
- `lr`: mixing weight given to the gradient-derived term; `1 - lr` is the weight kept
  by the current parameters.

# Returns
`pbc.cat_params` after the update.
"""
function apply_gradients_gpu(pbc::CatParamBitCircuit, param_grads::CuVector, leaf_param_grads::CuMatrix; lr::Float64 = 0.01)
    bc::CatBitCircuit = pbc.bitcircuit

    CUDA.@sync for layer in Iterators.reverse(bc.layers)
        num_nodes = length(layer)
        # An empty layer has no work, and log2(0) == -Inf cannot be converted to Int.
        num_nodes == 0 && continue
        num_threads = 2^min(ceil(Int, 2.0 * log2(num_nodes)), 8)
        # `apply_gradients_cuda` walks `layer` with a one-dimensional strided loop, so
        # enough blocks to cover each entry once are sufficient.
        num_blocks = cld(num_nodes, num_threads)
        @cuda threads=num_threads blocks=num_blocks apply_gradients_cuda(layer, bc.nodes, param_grads,
            pbc.params, lr)
    end

    # NOTE(review): `backprop_flows_down_layers_cuda` accumulates `exp(flow)` (linear
    # domain) into `leaf_param_grads`, yet the rows are normalized here with `logsumexp`
    # as if they were log values; `apply_gradients_cuda` instead takes `log(grad / sum)`.
    # Behavior is kept as is — confirm which normalization is intended.
    leaf_param_grads .-= logsumexp(leaf_param_grads; dims = 2)

    # Weight both terms in log space, then combine them with a stable log-add-exp.
    pbc.cat_params .+= log(1.0 - lr)
    leaf_param_grads .+= log(lr)
    # Equal entries (including two -Inf) get delta 0 instead of the NaN that
    # `-Inf - -Inf` would produce.
    delta = @. CUDA.ifelse(pbc.cat_params == leaf_param_grads, CUDA.zero(leaf_param_grads),
                           CUDA.abs(pbc.cat_params - leaf_param_grads))
    @. pbc.cat_params = CUDA.max(pbc.cat_params, leaf_param_grads) + CUDA.log1p(CUDA.exp(-delta))
    pbc.cat_params .-= logsumexp(pbc.cat_params; dims = 2) # normalize away any leftover error
    return pbc.cat_params
end

# CUDA kernel: for each node in `layer` that does not satisfy `has_single_child`, mix
# the node's log parameters with its normalized gradients and re-normalize them.
#
# Rows 1 and 2 of `nodes` give the first and last element index of a node. For every
# element `e` in that range the new value is the log-domain sum of
# `params[e] + log(1 - lr)` and `log(param_grads[e] / (sum of grads + 1e-8)) + log(lr)`;
# the node's parameters are then shifted so that their log-sum-exp is zero.
function apply_gradients_cuda(layer, nodes, param_grads, params, lr::Float64)
    first_i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    step = gridDim().x * blockDim().x

    log_keep = CUDA.log(1.0 - lr)
    log_new = CUDA.log(lr)
    for i = first_i : step : length(layer)
        node = @inbounds layer[i]
        # Nodes flagged by `has_single_child` are left untouched.
        if has_single_child(nodes, node)
            continue
        end
        lo = nodes[1, node]
        hi = nodes[2, node]

        # Total (linear-domain) gradient over this node's elements.
        grad_total = zero(eltype(param_grads))
        @inbounds for e = lo : hi
            grad_total += param_grads[e]
        end

        # Blend old parameter and normalized gradient; track the new normalizer.
        log_norm = -Inf
        @inbounds for e = lo : hi
            target = log(param_grads[e] / (grad_total + 1e-8)) + log_new
            params[e] = logsumexp_cuda(params[e] + log_keep, target)
            log_norm = logsumexp_cuda(log_norm, params[e])
        end

        # Re-normalize the node's parameters in log space.
        @inbounds for e = lo : hi
            params[e] -= log_norm
        end
    end
    return nothing
end