using LogicCircuits
using ProbabilisticCircuits
using BenchmarkTools
using Printf: @printf
using DataFrames
using StatsBase: corspearman
using CUDA
using StatsFuns: logsumexp
using DelimitedFiles

include("./utils.jl")


"""
    train_em(pc, train_data, valid_data; kwargs...) -> pc

Fit the parameters of `pc` in three consecutive phases and return `pc`:

1. **Warm-up EM** (`em_warmup_iters` iterations): mini-batch EM with
   `update_per_batch = true`; `exp_update_factor` moves linearly from
   `exp_update_factor_start` to `exp_update_factor_end`.
2. **Fine-tune EM** (`em_finetune_iters` iterations): EM with `exp_update_factor = 0.0`.
3. **SGD** (`sgd_iters` iterations): the learning rate moves linearly from
   `sgd_lr_start` to `sgd_lr_end`.

After every iteration the average marginal log-likelihood on the (possibly subsampled)
training and validation data is printed. The learning routines are called for their effect
on `pc`; their return values are ignored.

# Keywords
- `train_mode`: `"normal"` or `"categorical"`; selects which family of learning and
  evaluation routines is called. Any other value raises an error in the first iteration.
- `use_gpu`, `gpu_idx`: when `use_gpu` is true a CUDA device is selected; `gpu_idx == 0`
  delegates the choice to `find_gpu_device()`, otherwise `gpu_idx` indexes `devices()`.
- `soft_reg`: if greater than `1e-8`, `train_data` is passed through `soften`.
- `batch_size`: requested mini-batch size; capped at `1e9 ÷ num_edges(pc)` (at least 1).
- `pseudocount`: forwarded to the EM routines.
- `entropy_reg`: forwarded to `estimate_parameters_em` (normal mode only).
- `leaf_node_distribution`, `std_max`, `frac`, `n_variables`: forwarded to the
  categorical-mode routines.
- `data_subsample_frac`: if below `1.0`, each iteration draws fresh train/valid subsets
  with `bagging_dataset`.
- `model_save_path`, `model_save_interval`: when a path is given, `pc` is saved there
  every `model_save_interval` iterations of each phase.
- `log_file_name`: in categorical mode, bits-per-dimension values are appended to
  `logs/<log_file_name>` when this is not `nothing`.
"""
function train_em(pc::ProbCircuit, train_data::DataFrame, valid_data::DataFrame;
                  train_mode = "normal", gpu_idx = 0,
                  pseudocount = 0.1, soft_reg = 0.002, entropy_reg = 0.0,
                  use_gpu = true, batch_size = 1024,
                  em_warmup_iters = 100, em_finetune_iters = 100, sgd_iters = 0,
                  exp_update_factor_start = 0.2, exp_update_factor_end = 0.9,
                  sgd_lr_start = 0.01, sgd_lr_end = 0.002,
                  model_save_path::Union{Nothing,String} = nothing,
                  model_save_interval::Integer = 10,
                  data_subsample_frac::AbstractFloat = 1.0, leaf_node_distribution::String = "categorical",
                  std_max = 1e5, frac = 0.01, n_variables::Integer = 0,
                  log_file_name::Union{Nothing,String} = nothing)

    # Specify CUDA device to use
    if use_gpu
        gpu_idx = (gpu_idx == 0) ? find_gpu_device() : gpu_idx
        device!(collect(devices())[gpu_idx])
    end

    # Apply soft regularization
    if soft_reg > 1e-8
        train_data = soften(train_data, soft_reg; scale_by_marginal = true)
    end

    # Mini-batch size: cap by circuit size, but never let the cap drop below one example.
    batch_size = min(batch_size, max(1, Int64(1e9 ÷ num_edges(pc))))

    # ---- Phase 1: warm-up EM with per-batch updates ----
    for iter = 1 : em_warmup_iters
        @printf("      - Warmup iter %d:\n", iter)
        curr_data, curr_val_data = _select_iter_data(train_data, valid_data, data_subsample_frac)

        # To estimate circuit parameters under the existence of latent/hidden variables,
        # we employ the Expectation-Maximization (EM) algorithm.
        print("        - Estimating parameters with EM")
        t = @elapsed begin
            exp_update_factor = _linear_schedule(exp_update_factor_start, exp_update_factor_end,
                                                 iter, em_warmup_iters)
            if train_mode == "normal"
                estimate_parameters_em(pc, batch(curr_data, batch_size); pseudocount, use_gpu,
                    exp_update_factor, update_per_batch = true, entropy_reg)
            elseif train_mode == "categorical"
                estimate_parameters_em_cat(pc, batch(curr_data, batch_size); pseudocount, use_gpu,
                    exp_update_factor, update_per_batch = true, leaf_node_distribution,
                    std_max, frac, n_variables)
            else
                error("Unknown mode: $(train_mode)")
            end
        end
        @printf(" (%.3fs)\n", t)

        _evaluate_and_report(pc, curr_data, curr_val_data, train_data;
                             log_label = "Warmup", iter, train_mode, batch_size, use_gpu,
                             n_variables, log_file_name)
        _maybe_save_pc(pc, model_save_path, model_save_interval, iter)
    end

    # ---- Phase 2: fine-tune EM (exp_update_factor fixed at 0.0) ----
    for iter = 1 : em_finetune_iters
        @printf("      - Fine-tune iter %d:\n", iter)
        curr_data, curr_val_data = _select_iter_data(train_data, valid_data, data_subsample_frac)

        print("        - Estimating parameters with EM")
        t = @elapsed begin
            data = batch(curr_data, batch_size)
            if train_mode == "normal"
                estimate_parameters_em(pc, data; pseudocount, use_gpu, exp_update_factor = 0.0,
                                       entropy_reg)
            elseif train_mode == "categorical"
                estimate_parameters_em_cat(pc, data; pseudocount, use_gpu, exp_update_factor = 0.0,
                                           leaf_node_distribution, std_max, frac, n_variables)
            else
                error("Unknown mode: $(train_mode)")
            end
        end
        @printf(" (%.3fs)\n", t)

        _evaluate_and_report(pc, curr_data, curr_val_data, train_data;
                             log_label = "Finetune", iter, train_mode, batch_size, use_gpu,
                             n_variables, log_file_name)
        _maybe_save_pc(pc, model_save_path, model_save_interval, iter)
    end

    # ---- Phase 3: stochastic gradient descent with a linearly changing learning rate ----
    for iter = 1 : sgd_iters
        @printf("      - SGD iter %d:\n", iter)
        curr_data, curr_val_data = _select_iter_data(train_data, valid_data, data_subsample_frac)

        # This phase runs SGD, so the progress message names SGD rather than EM.
        print("        - Estimating parameters with SGD")
        t = @elapsed begin
            data = batch(curr_data, batch_size)
            lr = _linear_schedule(sgd_lr_start, sgd_lr_end, iter, sgd_iters)
            if train_mode == "normal"
                sgd_parameter_learning(pc, data; lr, use_gpu)
            elseif train_mode == "categorical"
                sgd_cat(pc, data; lr, use_gpu)
            else
                error("Unknown mode: $(train_mode)")
            end
        end
        @printf(" (%.3fs)\n", t)

        # `n_variables` is forwarded here exactly as in the two EM phases so that all
        # three phases evaluate the likelihood with the same settings.
        _evaluate_and_report(pc, curr_data, curr_val_data, train_data;
                             log_label = "SGD", iter, train_mode, batch_size, use_gpu,
                             n_variables, log_file_name)
        _maybe_save_pc(pc, model_save_path, model_save_interval, iter)
    end

    return pc
end

# Linearly interpolate from `start_val` (at `iter == 1`) to `end_val` (at
# `iter == total_iters`). A one-iteration schedule yields `start_val`.
function _linear_schedule(start_val, end_val, iter::Integer, total_iters::Integer)
    # Clamp the denominator: with a single iteration the naive formula is 0 / 0 (NaN).
    n_steps = max(total_iters - 1, 1)
    return start_val + (iter - 1) * (end_val - start_val) / n_steps
end

# Pick the (train, valid) data for one iteration: fresh `bagging_dataset` subsamples when
# `data_subsample_frac < 1.0`, otherwise the full datasets unchanged.
function _select_iter_data(train_data, valid_data, data_subsample_frac)
    if data_subsample_frac < 1.0
        curr_data = bagging_dataset(train_data; num_bags = 1,
                                    frac_examples = data_subsample_frac)[1]
        curr_val_data = bagging_dataset(valid_data; num_bags = 1,
                                        frac_examples = data_subsample_frac)[1]
        return curr_data, curr_val_data
    end
    return train_data, valid_data
end

# Evaluate average marginal log-likelihoods on the current train/valid data, print them,
# and in categorical mode also print (and optionally log) bits-per-dimension.
function _evaluate_and_report(pc, curr_data, curr_val_data, full_train_data;
                              log_label, iter, train_mode, batch_size, use_gpu,
                              n_variables, log_file_name)
    print("        - Evaluating marginal log-likelihoods")
    t = @elapsed begin
        if train_mode == "normal"
            ll_train = marginal_log_likelihood_avg(pc, batch(curr_data, batch_size); use_gpu)
            ll_valid = marginal_log_likelihood_avg(pc, batch(curr_val_data, batch_size); use_gpu)
        elseif train_mode == "categorical"
            ll_train = marginal_log_likelihood_avg_cat(pc, batch(curr_data, batch_size);
                                                       use_gpu, n_variables)
            ll_valid = marginal_log_likelihood_avg_cat(pc, batch(curr_val_data, batch_size);
                                                       use_gpu, n_variables)
        else
            error("Unknown mode: $(train_mode)")
        end
    end
    @printf(" (%.3fs)\n", t)
    @printf("        - Train: \033[1m%.3f\033[0m; Valid: \033[1m%.3f\033[0m\n",
            ll_train, ll_valid)

    if train_mode == "categorical"
        # Convert nats to bits and normalise by the feature count of the training data.
        bpd_train = -ll_train * log(MathConstants.e) / log(2) / num_features(full_train_data)
        bpd_valid = -ll_valid * log(MathConstants.e) / log(2) / num_features(full_train_data)
        @printf("        - Train bpd: \033[1m%.3f\033[0m; Valid bpd: \033[1m%.3f\033[0m\n",
                bpd_train, bpd_valid)
        _append_log_line(log_file_name,
            "$(log_label) iteration $(iter) - Train bpd: $(bpd_train); Valid bpd: $(bpd_valid)\n")
    end
    return nothing
end

# Append `line` to `logs/<log_file_name>`; does nothing when `log_file_name` is `nothing`.
function _append_log_line(log_file_name, line)
    log_file_name === nothing && return nothing
    log_path = "logs/" * log_file_name
    # Create the target directory first so the first write cannot fail on a missing folder.
    mkpath(dirname(log_path))
    open(log_path, "a+") do f
        write(f, line)
    end
    return nothing
end

# Save `pc` every `model_save_interval` iterations when a save path is configured.
function _maybe_save_pc(pc, model_save_path, model_save_interval, iter)
    # Test the path first: with saving disabled, a zero interval must not reach `%`.
    if model_save_path !== nothing && iter % model_save_interval == 0
        save_pc(model_save_path, pc)
    end
    return nothing
end