include("../src/Juice-compression.jl");
include("load_mnist.jl");

# Load the datasets and report how long loading took.
print("Loading dataset mnist")
t = @elapsed begin
    train_data, test_data = load_mnist(; flatten = false)
    # Only the first of the three returned values is kept for "binarized_mnist".
    bin_mnist, _, _ = twenty_datasets("binarized_mnist")

    num_vars = num_features(train_data)
    # 2^8 = 256; presumably the number of intensity levels per pixel (TODO confirm).
    # Not referenced again in the visible part of this script.
    num_cats = 2^8
end
@printf " (%.3fs)\n" t
println("> Features: $(num_vars); Train Examples: $(num_examples(train_data)); Test Examples: $(num_examples(test_data))")

# Device selection happens before anything is transferred with `to_gpu`.
select_gpu_device();

pc = load_pc("MNIST.pc"); # PC generated by train_hclt.jl

# Wrap `pc` in a CompressParamBitCircuit and hand it to `to_gpu` in a single step.
pbc = to_gpu(CompressParamBitCircuit(pc));

# Evaluate on a deep copy of the first 10000 rows of the test split,
# so `data` is independent of `test_data` before being passed to `to_gpu`.
data = to_gpu(deepcopy(test_data[1:10000, :]));

# Compression benchmark. The same call is issued three times:
#   1. a cold call without a reuse object (also produces `reuse_enc`),
#   2. a call that passes `reuse_enc` back in,
#   3. the measured call, wrapped in `@time`.
# The first two presumably act as warm-up so compilation/allocation of the
# reusable state is not part of the reported time — TODO confirm.
code, reuse_enc = compress_data_rans(pbc, data; precision = 30, use_gpu = true);
code, reuse_enc = compress_data_rans(pbc, data, reuse_enc; precision = 30, use_gpu = true);
@time code, reuse_enc = compress_data_rans(pbc, data, reuse_enc; precision = 30, use_gpu = true);
# Recorded `@time` outputs from earlier runs. The leading number presumably is the
# `precision` setting used for that run (note: the calls above use 30) — TODO confirm.
# 16: 14.697904 seconds (269.67 M allocations: 4.236 GiB, 9.67% gc time)
# 24: 25.931284 seconds (305.73 M allocations: 4.773 GiB, 2.84% gc time)
# 32: 43.547312 seconds (358.30 M allocations: 5.556 GiB, 1.11% gc time)

# Decompression benchmark, mirroring the compression one:
#   1. a cold call that produces the output `d` and the reuse object `reuse_dec`,
#   2. a call that passes both `d` and `reuse_dec` back in,
#   3. the measured call, wrapped in `@time`.
# `precision` matches the value used when `code` was produced above.
d, reuse_dec = decompress_marginal_rans(pbc, code; precision = 30, use_gpu = true);
d, reuse_dec = decompress_marginal_rans(pbc, code, d, reuse_dec; precision = 30, use_gpu = true);
@time d, reuse_dec = decompress_marginal_rans(pbc, code, d, reuse_dec; precision = 30, use_gpu = true);
# Recorded `@time` outputs from earlier runs. The leading number presumably is the
# `precision` setting used for that run (note: the calls above use 30) — TODO confirm.
# 16: 44.026582 seconds (125.64 M allocations: 3.296 GiB, 0.59% gc time)
# 24: 88.709477 seconds (262.94 M allocations: 5.341 GiB, 0.51% gc time)
# 32: 155.176581 seconds (456.10 M allocations: 8.220 GiB, 0.49% gc time)

# Achieved code length per pixel: summed length of every entry of `code`,
# divided by the number of examples and by the 28 x 28 pixels of each image.
# (Unit is presumably bits — TODO confirm against `compress_data_rans`.)
# NOTE: the value is only displayed when evaluated interactively.
sum(length, code) / num_examples(data) / 28 / 28
# 16: 1.3056572704081633
# 24: 1.2653904336734694
# 32: 1.2450542091836736

# Reference value: negated average marginal log-likelihood per pixel, divided by
# log(2) (a nats-to-bits conversion, assuming the likelihood uses the natural log).
# NOTE(review): evaluated on `pc` with `use_gpu = false` while `data` went through
# `to_gpu` — confirm this combination is intended.
-marginal_log_likelihood_avg_cat(pc, data; use_gpu = false) / log(2.0) / 28 / 28
# 16: 1.2640722812250218
# 24: 1.223837327071986
# 32: 1.2034889262999828

# Element-wise round-trip comparison between the original data and the decompressed `d`.
# The resulting mask is stored but not reduced or checked here (e.g. via `all(correct)`).
correct = convert(Matrix{UInt32}, to_cpu(data)) .== to_cpu(d);
