using MLDatasets
using DataFrames
using NPZ


# Identifiers of the core datasets; every entry is a name that `load_dataset` accepts.
DATASET_NAMES = String["MNIST", "FashionMNIST", "EMNIST", "KMNIST"]


"""
    load_dataset(dataset_name = "MNIST"; flatten = false, precision = 8)

Select the loader that matches `dataset_name` and return its result, a
`(train_data, test_data)` tuple.

# Arguments
- `dataset_name`: one of `"MNIST"`, `"FashionMNIST"`, `"EMNIST"`, `"KMNIST"`,
  `"EMNIST_digits"`, `"EMNIST_balanced"` or `"EMNIST_byclass"`.

# Keywords
- `flatten`, `precision`: forwarded unchanged to the selected loader.

# Throws
- `ArgumentError` if `dataset_name` is not one of the names listed above.
"""
function load_dataset(dataset_name = "MNIST"; flatten = false, precision = 8)
    if dataset_name == "MNIST"
        return load_mnist(; flatten, precision)
    elseif dataset_name == "FashionMNIST"
        return load_fashionmnist(; flatten, precision)
    elseif dataset_name == "EMNIST"
        return load_emnist(; flatten, precision)
    elseif dataset_name == "KMNIST"
        return load_kmnist(; flatten, precision)
    elseif dataset_name == "EMNIST_digits"
        return load_emnist_digits(; flatten, precision)
    elseif dataset_name == "EMNIST_balanced"
        return load_emnist_balanced(; flatten, precision)
    elseif dataset_name == "EMNIST_byclass"
        return load_emnist_byclass(; flatten, precision)
    end

    # Falling off the end of the chain used to return `nothing`, which only surfaced
    # later as a confusing error at the call site; fail here with the offending name.
    supported = (
        "MNIST", "FashionMNIST", "EMNIST", "KMNIST",
        "EMNIST_digits", "EMNIST_balanced", "EMNIST_byclass",
    )
    throw(ArgumentError(
        "unknown dataset name $(repr(dataset_name)); expected one of: " *
        join(supported, ", "),
    ))
end


"""
    load_mnist(; flatten = false, precision = 8)

Load the MNIST train and test images through `MNIST.traintensor(UInt8)` and
`MNIST.testtensor(UInt8)` and return them as a `(train_data, test_data)` tuple of
`DataFrame`s.

Each tensor is reshaped to `28 * 28` rows and transposed, so the tables have one row
per sample and 784 pixel columns before any bit expansion.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_mnist(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_pixels = convert(
        Matrix{UInt8}, transpose(reshape(MNIST.traintensor(UInt8), 28 * 28, :))
    )
    test_pixels = convert(
        Matrix{UInt8}, transpose(reshape(MNIST.testtensor(UInt8), 28 * 28, :))
    )

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end


"""
    load_fashionmnist(; flatten = false, precision = 8)

Load the FashionMNIST train and test images through `FashionMNIST.traintensor(UInt8)`
and `FashionMNIST.testtensor(UInt8)` and return them as a `(train_data, test_data)`
tuple of `DataFrame`s.

Each tensor is reshaped to `28 * 28` rows and transposed, so the tables have one row
per sample and 784 pixel columns before any bit expansion.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_fashionmnist(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_pixels = convert(
        Matrix{UInt8}, transpose(reshape(FashionMNIST.traintensor(UInt8), 28 * 28, :))
    )
    test_pixels = convert(
        Matrix{UInt8}, transpose(reshape(FashionMNIST.testtensor(UInt8), 28 * 28, :))
    )

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end


"""
    load_emnist(; flatten = false, precision = 8)

Load the EMNIST Letters train and test images through `EMNIST.Letters.traindata()` and
`EMNIST.Letters.testdata()` and return them as a `(train_data, test_data)` tuple of
`DataFrame`s.

The loaded arrays are reshaped to matrices with `28 * 28` columns; rows are treated as
samples. NOTE(review): this assumes `traindata()` / `testdata()` return a plain image
array whose first dimension indexes samples — confirm against the MLDatasets version
in use.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_emnist(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_pixels = convert(Matrix{UInt8}, reshape(EMNIST.Letters.traindata(), :, 28 * 28))
    test_pixels = convert(Matrix{UInt8}, reshape(EMNIST.Letters.testdata(), :, 28 * 28))

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end


"""
    load_emnist_digits(; flatten = false, precision = 8)

Load the EMNIST Digits train and test images through `EMNIST.Digits.traindata()` and
`EMNIST.Digits.testdata()` and return them as a `(train_data, test_data)` tuple of
`DataFrame`s.

The loaded arrays are reshaped to matrices with `28 * 28` columns; rows are treated as
samples. NOTE(review): this assumes `traindata()` / `testdata()` return a plain image
array whose first dimension indexes samples — confirm against the MLDatasets version
in use.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_emnist_digits(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_pixels = convert(Matrix{UInt8}, reshape(EMNIST.Digits.traindata(), :, 28 * 28))
    test_pixels = convert(Matrix{UInt8}, reshape(EMNIST.Digits.testdata(), :, 28 * 28))

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end


"""
    load_emnist_balanced(; flatten = false, precision = 8)

Load the EMNIST Balanced train and test images through `EMNIST.Balanced.traindata()`
and `EMNIST.Balanced.testdata()` and return them as a `(train_data, test_data)` tuple
of `DataFrame`s.

The loaded arrays are reshaped to matrices with `28 * 28` columns; rows are treated as
samples. NOTE(review): this assumes `traindata()` / `testdata()` return a plain image
array whose first dimension indexes samples — confirm against the MLDatasets version
in use.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_emnist_balanced(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_pixels = convert(Matrix{UInt8}, reshape(EMNIST.Balanced.traindata(), :, 28 * 28))
    test_pixels = convert(Matrix{UInt8}, reshape(EMNIST.Balanced.testdata(), :, 28 * 28))

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end

"""
    load_emnist_byclass(; flatten = false, precision = 8)

Load the EMNIST ByClass train and test images through `EMNIST.ByClass.traindata()` and
`EMNIST.ByClass.testdata()` and return them as a `(train_data, test_data)` tuple of
`DataFrame`s.

The loaded arrays are reshaped to matrices with `28 * 28` columns; rows are treated as
samples. NOTE(review): this assumes `traindata()` / `testdata()` return a plain image
array whose first dimension indexes samples — confirm against the MLDatasets version
in use.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_emnist_byclass(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_pixels = convert(Matrix{UInt8}, reshape(EMNIST.ByClass.traindata(), :, 28 * 28))
    test_pixels = convert(Matrix{UInt8}, reshape(EMNIST.ByClass.testdata(), :, 28 * 28))

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end


"""
    load_kmnist(; flatten = false, precision = 8)

Load the KMNIST train and test images from the `"arr_0"` entry of
`data/KMNIST/kmnist-train-imgs.npz` and `data/KMNIST/kmnist-test-imgs.npz` (paths are
relative to the current working directory) and return them as a
`(train_data, test_data)` tuple of `DataFrame`s.

The loaded arrays are reshaped to matrices with `28 * 28` columns; rows are treated as
samples. NOTE(review): this assumes the stored arrays index samples along their first
dimension — confirm against the files on disk.

# Keywords
- `flatten`: when `false`, every cell holds the pixel value plus one as a `UInt32`
  (values `1..256`). When `true`, every pixel is expanded into `precision` Boolean
  columns.
- `precision`: number of bits kept per pixel when `flatten` is `true`. The top
  `precision` bits of the 8-bit value are kept and stored least-significant first.
  Ignored when `flatten` is `false`.

# Throws
- `ArgumentError` if `flatten` is `true` and `precision` is not in `0:8`.
"""
function load_kmnist(; flatten = false, precision = 8)
    # A precision above 8 would make `8 - precision` negative, which Julia's `>>`
    # silently treats as a left shift and yields meaningless bit columns.
    if flatten && !(precision in 0:8)
        throw(ArgumentError(
            "precision must be an integer between 0 and 8 when flatten is true, got $(precision)",
        ))
    end

    train_raw = NPZ.npzread("data/KMNIST/kmnist-train-imgs.npz")["arr_0"]
    test_raw = NPZ.npzread("data/KMNIST/kmnist-test-imgs.npz")["arr_0"]
    train_pixels = convert(Matrix{UInt8}, reshape(train_raw, :, 28 * 28))
    test_pixels = convert(Matrix{UInt8}, reshape(test_raw, :, 28 * 28))

    # Bits are taken from the raw 0..255 value, not from the 1-based value used for the
    # categorical output: adding one first turns 255 into 256, whose kept bits are all
    # zero, so the largest pixel value would be encoded exactly like the smallest.
    function to_bits(pixels::Matrix{UInt8})
        bits = BitMatrix(undef, size(pixels, 1), size(pixels, 2) * precision)
        # Columns form the outer loops so writes follow Julia's column-major layout.
        for feature_idx in axes(pixels, 2), bit_idx in 1:precision
            column = precision * (feature_idx - 1) + bit_idx
            # Drop the (8 - precision) low bits, then select bit `bit_idx - 1`.
            shift = (8 - precision) + (bit_idx - 1)
            for sample_idx in axes(pixels, 1)
                bits[sample_idx, column] = isodd(pixels[sample_idx, feature_idx] >> shift)
            end
        end
        return DataFrame(bits)
    end

    if flatten
        return to_bits(train_pixels), to_bits(test_pixels)
    end

    # Non-flattened output keeps the 1-based UInt32 values (1..256).
    to_categorical(pixels) = DataFrame(convert(Matrix{UInt32}, pixels) .+ UInt32(1))
    return to_categorical(train_pixels), to_categorical(test_pixels)
end