using Printf

using LogicCircuits
using ProbabilisticCircuits
using DataFrames

include("../src/Juice-compression.jl")
include("./datasets/load_datasets.jl")


"""
    main(dataset_name = "MNIST"; gpu_idx = 0, bin_precision = 1, num_hidden_cats = 16,
         leaf_node_distribution = "categorical", pseudocount = 0.25)

Build a hidden Chow-Liu tree (HCLT) circuit for `dataset_name` and train it with `train_em`.

Steps:
1. Load the dataset twice with `load_dataset`:
   - un-flattened at precision 8, used for training and testing;
   - flattened at `bin_precision`, used only as `data_for_mi` when building the tree.
2. Build the circuit with `hidden_chow_liu_circuit_cat` and initialize it with
   `uniform_parameters_cat(pc; perturbation = 0.5)`.
3. Append the circuit's edge and parameter counts to `logs/<log_file_name>`, creating the
   `logs/` directory if it is missing.
4. Run `train_em` and return its result.

# Arguments
- `dataset_name`: dataset name passed to `load_dataset`. It is also used in the log file
  name and in the model save path `"<dataset_name>.pc"`.

# Keywords
- `gpu_idx`: forwarded to `train_em`.
- `bin_precision`: `precision` used to load the flattened copy for mutual-information
  estimation. That copy is described to the circuit builder as having
  `2^bin_precision` categories.
- `num_hidden_cats`: forwarded to `hidden_chow_liu_circuit_cat`; also part of the log file
  name.
- `leaf_node_distribution`: forwarded to `train_em`; also part of the log file name.
- `pseudocount`: forwarded to `train_em` (presumably EM smoothing — confirm); also part of
  the log file name.

# Returns
The value returned by `train_em` (assigned to `pc`, presumably the trained circuit).
"""
function main(dataset_name = "MNIST"; gpu_idx = 0, bin_precision = 1, num_hidden_cats = 16,
              leaf_node_distribution = "categorical", pseudocount = 0.25)
    # Precision used to load train/test data. The categorical leaves then have
    # 2^data_precision categories. Keeping one constant keeps the two values in sync.
    data_precision = 8

    # Load dataset
    print("Loading dataset $(dataset_name)")
    t = @elapsed begin
        train_data, test_data = load_dataset(dataset_name; flatten = false,
                                             precision = data_precision)
        # Coarser, flattened copy used only for the mutual-information estimate.
        bin_data, _ = load_dataset(dataset_name; flatten = true, precision = bin_precision)

        num_vars = num_features(train_data)
        num_cats = 2^data_precision
    end
    @printf(" (%.3fs)\n", t)
    println("> Features: $(num_vars); Train Examples: $(num_examples(train_data)); Test Examples: $(num_examples(test_data))")

    # HCLT hyperparameters
    num_trees = 1
    num_tree_candidates = 1
    tree_sample_type = "fixed_interval"
    dropout_prob = 0.0

    print("Generating `hidden_chow_liu_circuit`")
    t = @elapsed begin
        pc = hidden_chow_liu_circuit_cat(num_vars, num_cats; data = train_data, num_hidden_cats,
                                         num_tree_candidates, num_trees, tree_sample_type,
                                         dropout_prob,
                                         data_for_mi = (bin_data, num_vars, 2^bin_precision))

        # Add uniform prior to the circuit
        uniform_parameters_cat(pc; perturbation = 0.5)
        # kmeans_params_initialization_cat(pc, train_data)
    end
    @printf(" (%.3fs)\n", t)

    # Compute the size statistics once; they are both printed and logged.
    n_edges = num_edges(pc)
    n_params = num_parameters(pc)
    println("> Edges: $(n_edges); Parameters: $(n_params)")

    # EM parameters
    soft_reg = 0.0
    entropy_reg = 0.0

    log_file_name = "HCLT_$(dataset_name)_$(num_hidden_cats)_$(leaf_node_distribution)_$(pseudocount).txt"

    # `open(path, "a+")` does not create missing parent directories, so ensure `logs/` exists.
    mkpath("logs")
    open(joinpath("logs", log_file_name), "a+") do f
        write(f, "Edges: $(n_edges); Parameters: $(n_params)\n")
    end

    # Training. `train_em` receives the bare `log_file_name`; how it resolves that name
    # is defined elsewhere.
    pc = train_em(pc, train_data, test_data; gpu_idx, pseudocount, soft_reg, entropy_reg,
                  em_warmup_iters = 120, em_finetune_iters = 20, train_mode = "categorical",
                  exp_update_factor_start = 0.85, exp_update_factor_end = 0.95,
                  batch_size = 1024, log_file_name, leaf_node_distribution,
                  model_save_path = "$(dataset_name).pc", model_save_interval = 10)
    return pc
end


# Settings shared by every experiment run below.
gpu_idx = 1        # forwarded by `main` to `train_em`
bin_precision = 3  # precision of the flattened copy `main` uses for mutual information

# Categorical-leaf HCLT runs with 32 hidden categories, one dataset after another.
num_hidden_cats = 32
for name in ("FashionMNIST", "EMNIST", "EMNIST_byclass")
    main(name;
         gpu_idx = gpu_idx,
         bin_precision = bin_precision,
         num_hidden_cats = num_hidden_cats,
         leaf_node_distribution = "categorical")
end

#=num_hidden_cats = 8

for dataset_name in DATASET_NAMES
    main(dataset_name; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "sparse")
end

num_hidden_cats = 12

for dataset_name in DATASET_NAMES
    main(dataset_name; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "sparse")
end

num_hidden_cats = 14

for dataset_name in DATASET_NAMES
    main(dataset_name; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "sparse")
end=#

# num_hidden_cats = 32
# main("EMNIST_balanced"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical")

#=num_hidden_cats = 42
pseudocount = 0.25
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)

num_hidden_cats = 42
pseudocount = 0.5
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)

num_hidden_cats = 42
pseudocount = 1.0
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)

num_hidden_cats = 42
pseudocount = 2.0
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)

num_hidden_cats = 48
pseudocount = 0.5
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)

num_hidden_cats = 48
pseudocount = 1.0
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)

num_hidden_cats = 48
pseudocount = 2.0
main("FashionMNIST"; gpu_idx, bin_precision, num_hidden_cats, leaf_node_distribution = "categorical", pseudocount)=#