/* Command line to build and run on x86:

c++ doc/quantization_example.cc -I . --std=c++11 -msse4.1 -lpthread \
    -o /tmp/quantization_example && \
/tmp/quantization_example

*/

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <exception>
#include <iostream>
#include <random>
#include <string>
#include <vector>

#include "../gemmlowp/public/gemmlowp.h"
#include "layer.h"
#include "matrix_storage.h"
#include "matrix_multiplication.h"
#include "utils.h"

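// We handle both float and quantized matrices, which we represent as
// gemmlowp::MatrixMap views.
// We need to be able to print them to a std::ostream; the operator<<
// overloads used for that are not defined in this file (presumably they are
// provided by the project headers included above).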

bool arg_debug_mode = false;

constexpr auto kDefaultRows = 200;
constexpr auto kDefaultCols = 200;
constexpr auto kDefaultDepth = 200;
const auto kRowOrder = gemmlowp::MapOrder::RowMajor;
const auto kColOrder = gemmlowp::MapOrder::ColMajor;
constexpr auto kDefaultActivationBit = 7;



// Benchmarks the reference int8 matrix multiplication against the packed
// variants and reports how far the packed results deviate from the reference.
//
// Parameters:
//   rows, depth, cols - the LHS is rows x depth, the RHS is depth x cols and
//                       every result matrix is rows x cols.
//
// Prints to std::cout the average time per run (in milliseconds) of the
// normal and of the masked packed multiplication, followed by the mean
// absolute per-entry difference of each packed result against the reference
// result. When arg_debug_mode is set the operands and results are dumped too.
void TestMatrixMultiplication(int rows, int depth, int cols) {
    // Initialize the int8 operands.
    PackMatrixWithStorage<int8_t, int8_t, kRowOrder> int8_trinary_lhs(rows, depth, 7);
    int8_trinary_lhs.MakeTrinaryRandom();
    int8_trinary_lhs.PrepareAuxMatrixMap();
    PackMatrixWithStorage<int8_t, int8_t, kColOrder> int8_rhs(depth, cols, 7);
    int8_rhs.MakeInt8Random();

    // One result matrix per multiplication variant.
    PackMatrixWithStorage<int8_t, int8_t, kRowOrder> reference_int8_result(rows, cols, 8);
    PackMatrixWithStorage<int8_t, int8_t, kRowOrder> pack_int8_mask_result(rows, cols, 8);
    PackMatrixWithStorage<int8_t, int8_t, kRowOrder> pack_int8_result(rows, cols, 8);
    auto reference_int8_result_map = reference_int8_result.Map();

    if (arg_debug_mode) {
        std::cout << "Here is the int8 LHS matrix:\n" << int8_trinary_lhs << std::endl;
        std::cout << "Here is the int8 LHS B1rep matrix:\n" <<
            int8_trinary_lhs.ConstMap(MapOption::B1REP) << std::endl;
        std::cout << "Here is the int8 LHS B2rep matrix:\n" <<
            int8_trinary_lhs.ConstMap(MapOption::B2REP) << std::endl;
        std::cout << "Here is the int8 LHS B2first matrix:\n" <<
            int8_trinary_lhs.ConstMap(MapOption::B2FIRST) << std::endl;
        std::cout << "Here is the int8 RHS matrix:\n" << int8_rhs << std::endl;

        std::cout << "Here is the int32 LHS matrix:\n" << int8_trinary_lhs.ConstMapPack() << std::endl;
        std::cout << "Here is the int32 LHS B1rep matrix:\n" << int8_trinary_lhs.ConstMapPack(MapOption::B1REP) << std::endl;
        std::cout << "Here is the int32 LHS B2rep matrix:\n" << int8_trinary_lhs.ConstMapPack(MapOption::B2REP) << std::endl;
        std::cout << "Here is the int32 LHS B2first matrix:\n" << int8_trinary_lhs.ConstMapPack(MapOption::B2FIRST) << std::endl;
        std::cout << "Here is the int32 RHS matrix:\n" << int8_rhs.ConstMapPack() << std::endl;
    }

    // Number of timed repetitions; the same constant drives the loop and the
    // averaging so the two can no longer drift apart.
    constexpr int kTimedRuns = 5;
    double normal_ticks = 0.0;
    double pack_full_ticks = 0.0;
    double pack_mask_ticks = 0.0;

    for (int run = 0; run < kTimedRuns; ++run) {
        const clock_t normal_start = clock();
        NormalMatrixMultiplication(int8_trinary_lhs.ConstMap(), int8_rhs.ConstMap(),
                                   &reference_int8_result_map);
        normal_ticks += static_cast<double>(clock() - normal_start);

        const clock_t pack_full_start = clock();
        PackMatrixMultiplicationFull(int8_trinary_lhs, int8_rhs, &pack_int8_result);
        pack_full_ticks += static_cast<double>(clock() - pack_full_start);

        //int8_rhs.ApplyActivationMask();
        const clock_t pack_mask_start = clock();
        PackMatrixMultiplicationMask(int8_trinary_lhs, int8_rhs, &pack_int8_mask_result);
        pack_mask_ticks += static_cast<double>(clock() - pack_mask_start);
    }

    if (arg_debug_mode) {
        std::cout << "Here is the int8 product (LHS * RHS) matrix obtained by "
                  << "pack matrix multiplication\n"
                  << pack_int8_result << std::endl;
        std::cout << "Here is the int8 product (LHS * RHS) matrix obtained by "
                  << "ordinary matrix multiplication, i.e. as far as we are "
                  << "concerned, the REFERENCE RESULT:\n"
                  << reference_int8_result << std::endl;
    }

    // Converts an accumulated tick count into average milliseconds per run.
    // The reported unit is milliseconds, so the label says "ms".
    const double ticks_to_avg_ms =
        1000.0 / (static_cast<double>(CLOCKS_PER_SEC) * kTimedRuns);
    std::cout << "Normal MM takes " << normal_ticks * ticks_to_avg_ms << " ms\n";
    // The full pack timing is still collected, but its report is currently disabled.
    //std::cout << "Pack MM takes " << pack_full_ticks * ticks_to_avg_ms << " ms\n";
    std::cout << "Pack Mask MM takes " << pack_mask_ticks * ticks_to_avg_ms << " ms\n";

    // Accumulate the absolute element-wise differences against the reference.
    auto reference_map = reference_int8_result.Map();
    auto pack_full_map = pack_int8_result.Map();
    auto pack_mask_map = pack_int8_mask_result.Map();

    double diff_sum_pm = 0.0;
    double diff_sum_pm_mask = 0.0;
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            // Widen to int before subtracting so the difference cannot wrap.
            const int expected = reference_map(i, j);
            const int full_value = pack_full_map(i, j);
            const int mask_value = pack_mask_map(i, j);
            diff_sum_pm += std::abs(expected - full_value);
            diff_sum_pm_mask += std::abs(expected - mask_value);
        }
    }

    // Computed in double so that rows * cols cannot overflow int.
    const double entry_count = static_cast<double>(rows) * cols;
    std::cout << "The absolute diff per entry for pm is " << diff_sum_pm / entry_count << "\n";
    std::cout << "The absolute diff per entry for pm mask is " << diff_sum_pm_mask / entry_count << "\n";
}

// Runs Img2Col on an index-filled image batch and prints how long it took.
//
// Parameters:
//   img_row, img_col - spatial size of each image.
//   channel          - number of channels per image.
//   img_num          - number of images in the batch.
//
// Uses a fixed padding of 1, stride of 2 and kernel size of 3. Prints the
// elapsed time in seconds to std::cout and, when arg_debug_mode is set, the
// source image matrix and the Img2Col output matrix.
void TestImgToCol(int img_row, int img_col, int channel, int img_num) {
    auto img_mat = PackImageStorage<int8_t, int32_t, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    img_mat.FillValueWithIndex();
    int padding = 1;
    int stride = 2;
    int ks = 3;

    // Number of kernel positions along each axis, and per image.
    const int row_patch_num = (img_row + 2 * padding - ks) / stride + 1;
    const int col_patch_num = (img_col + 2 * padding - ks) / stride + 1;
    const int patch_num = row_patch_num * col_patch_num;

    // One column holds ks * ks * channel values, rounded up to a multiple of 4.
    // NOTE(review): presumably required by the int32 packing - confirm.
    const int patch_size = ks * ks * channel;
    const int output_row = (patch_size / 4 + (patch_size % 4 > 0 ? 1 : 0)) * 4;

    auto img_to_col_mat = PackMatrixWithStorage<int8_t, int32_t, kColOrder>(output_row, patch_num * img_num, kDefaultActivationBit);

    const clock_t img_to_col_start = clock();
    img_mat.Img2Col(padding, stride, ks, &img_to_col_mat);
    const float img_to_col_time = static_cast<float>(clock() - img_to_col_start) / CLOCKS_PER_SEC;
    std::cout << "Image to column time is " << img_to_col_time << "\n";

    if (arg_debug_mode) {
        std::cout << "Image storage matrix is \n" << img_mat.ConstMap();
        std::cout << "Img2Col matrix is \n" << img_to_col_mat.ConstMap();
    }
}

// Times the forward pass of a randomly initialized convolution layer with the
// NORMAL and then the PACK matrix-multiplication engine.
//
// Parameters:
//   img_row, img_col - spatial size of each image.
//   channel          - number of input channels.
//   img_num          - number of images in the batch.
//
// Uses a fixed padding of 1, stride of 2, kernel size of 3 and 6 kernels.
// Prints both timings (seconds) to std::cout; when arg_debug_mode is set the
// input, Img2Col, weight and result matrices are dumped after each run.
void TestConvForward(int img_row, int img_col, int channel, int img_num) {
    auto img_mat = PackImageStorage<int8_t, int32_t, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    img_mat.FillValueWithIndex();
    int padding = 1;
    int stride = 2;
    int ks = 3;
    int kernel_num = 6;
    auto layer = Conv2DDefault(ks, channel, kernel_num, padding, stride, kDefaultActivationBit);
    layer.InitRandomWeight();

    auto dequant_layer = DeQuantDefault(1.5, 0);

    // First run: normal engine.
    SetMatMulEngine(MMOPTION::NORMAL);
    const clock_t normal_mm_start = clock();
    auto result = layer.Forward(img_mat);
    const float normal_mm_time = static_cast<float>(clock() - normal_mm_start) / CLOCKS_PER_SEC;
    // Dequantization runs outside the timed region so that both engines are
    // timed on exactly the same work (the convolution forward pass only).
    auto dequant_result = dequant_layer.Forward(result);
    std::cout << "Time result for normal mat mul is " << normal_mm_time << "\n";

    // Dumps the matrices shared by both runs. `result` is captured by
    // reference, so the second call shows the pack-engine output.
    auto dump_forward_state = [&]() {
        std::cout << "Image storage matrix is \n" << img_mat.ConstMap();
        std::cout << "Img2col matrix is \n" << layer.GetImg2ColMatrix().ConstMap();
        std::cout << "Weight matrix is \n" << layer.GetWeightMatrix().ConstMap();
        std::cout << "Result matrix is \n" << result.ConstMap();
    };

    if (arg_debug_mode) {
        dump_forward_state();
        std::cout << "DeQuant Result matrix is \n" << dequant_result.ConstMap();
    }

    // Second run: pack engine on the same layer and input.
    SetMatMulEngine(MMOPTION::PACK);
    const clock_t pack_mm_start = clock();
    result = layer.Forward(img_mat);
    const float pack_mm_time = static_cast<float>(clock() - pack_mm_start) / CLOCKS_PER_SEC;
    std::cout << "Time result for pack mat mul is " << pack_mm_time << "\n";
    if (arg_debug_mode) {
        dump_forward_state();
    }
}

// Times the forward pass of a float ReLU layer on a randomly filled image
// batch of the given dimensions and prints the elapsed seconds to std::cout.
// In debug mode the input and the result matrices are printed as well.
void TestReluForward(int img_row, int img_col, int channel, int img_num) {
    auto input = PackImageStorage<float, float, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    input.MakeRandomLarge();
    auto relu = ReluFloatDefault();
    // Constructed but not applied: the quantization step is commented out below.
    auto quantizer = QuantDefault(2, 0);

    const clock_t started = clock();
    auto activated = relu.Forward(input);
    // auto quant_result = quantizer.Forward(activated);
    auto quant_result = activated;  // plain copy, still inside the timed region
    const clock_t finished = clock();

    const float relu_time = static_cast<float>(finished - started) / CLOCKS_PER_SEC;
    std::cout << "Time result for relu is " << relu_time << "\n";

    if (!arg_debug_mode) {
        return;
    }
    std::cout << "Image storage matrix is \n" << input.ConstMap();
    std::cout << "Result matrix is \n" << activated.ConstMap();
    // std::cout << "Quant result matrix is \n" << quant_result.ConstMap();
}

// Times a batch-norm layer's forward pass followed by a quantization layer's
// forward pass on a randomly filled float image batch, and prints the elapsed
// seconds to std::cout. In debug mode the input, the layer's weight, bias,
// mean and std matrices, and both results are printed as well.
void TestBNForward(int img_row, int img_col, int channel, int img_num) {
    auto input = PackImageStorage<float, float, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    input.MakeRandomLarge();

    auto bn = BNDefault(channel);
    bn.InitRandomParam(input);
    auto quantizer = QuantDefault(1, 0);

    // The timed region covers both the BN and the quantization forward pass.
    const clock_t started = clock();
    auto normalized = bn.Forward(input);
    auto quantized = quantizer.Forward(normalized);
    const clock_t finished = clock();

    const float bn_time = static_cast<float>(finished - started) / CLOCKS_PER_SEC;
    std::cout << "Time result for BN is " << bn_time << "\n";

    if (!arg_debug_mode) {
        return;
    }
    std::cout << "Image storage matrix is \n" << input.ConstMap();
    std::cout << "Weight mat is \n" << bn.GetWeightMatrix().ConstMap();
    std::cout << "Bias mat is \n" << bn.GetBiasMatrix().ConstMap();
    std::cout << "Mean mat is \n" << bn.GetMeanMatrix().ConstMap();
    std::cout << "Std mat is \n" << bn.GetStdMatrix().ConstMap();
    std::cout << "Result matrix is \n" << normalized.ConstMap();
    std::cout << "Quant result matrix is \n" << quantized.ConstMap();
}

// Times the forward pass of a max-pooling layer (constructed with argument 2)
// on a randomly filled int8 image batch and prints the elapsed seconds to
// std::cout. In debug mode the input and the result matrices are printed too.
void TestMaxPoolingForward(int img_row, int img_col, int channel, int img_num) {
    auto input = PackImageStorage<int8_t, int32_t, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    // Alternative fill kept for reference: input.FillValueWithIndex();
    input.MakeInt8Random();
    auto pooling = MaxPoolingDefault(2);

    const clock_t started = clock();
    auto pooled = pooling.Forward(input);
    const clock_t finished = clock();

    const float mp_time = static_cast<float>(finished - started) / CLOCKS_PER_SEC;
    std::cout << "Time result for max pooling is " << mp_time << "\n";

    if (!arg_debug_mode) {
        return;
    }
    std::cout << "Image storage matrix is \n" << input.ConstMap();
    std::cout << "Result matrix is \n" << pooled.ConstMap();
}

void TestReadMatrix() {
    std::string path_1 = "./src/test_file/m1";
    std::string path_2 = "./src/test_file/m2";
    PackMatrixWithStorage<int8_t, int32_t, kRowOrder> mat_i(path_1, 8);
    PackMatrixWithStorage<float, float, kRowOrder> mat_f(path_2, 8);
    if (arg_debug_mode) {
        std::cout << mat_i << "\n";
        std::cout << mat_f;
    }
}

// Builds a float convolution layer (3 -> 128 channels) and an integer
// convolution layer (128 -> 128 channels), initializes their weights from the
// files under ./model/cpp/vgg_bwn_cyc and, in debug mode, prints both weight
// matrices to std::cout.
void TestReadConvWeight() {
    int padding = 0;
    int stride = 1;
    int ks = 3;

    // Float layer.
    int fp_channel = 3;
    int fp_kernel_num = 128;
    auto fp_layer = Conv2DFloatDefault(ks, fp_channel, fp_kernel_num, padding, stride, kDefaultActivationBit);

    // Integer layer.
    int int_channel = 128;
    int int_kernel_num = 128;
    auto int_layer = Conv2DDefault(ks, int_channel, int_kernel_num, padding, stride, kDefaultActivationBit);
    int_layer.InitPackMatMul();

    std::string fp_weight_path("./model/cpp/vgg_bwn_cyc/conv1");
    std::string int_weight_path("./model/cpp/vgg_bwn_cyc/conv2");

    fp_layer.InitWeightFromFile(fp_weight_path);
    int_layer.InitWeightFromFile(int_weight_path);

    if (!arg_debug_mode) {
        return;
    }
    std::cout << fp_weight_path << " result is\n" << fp_layer.GetWeightMatrix();
    std::cout << "\n\n\n";
    std::cout << int_weight_path << " result is\n" << int_layer.GetWeightMatrix();
}

// Builds a batch-norm layer with 128 features, initializes its parameters
// from ./model/cpp/vgg_bwn_cyc/bn1 and, in debug mode, prints the weight,
// bias, mean and std matrices to std::cout.
void TestReadBNParam() {
    int feature_dim = 128;
    auto bn = BNDefault(feature_dim);

    std::string param_path("./model/cpp/vgg_bwn_cyc/bn1");
    bn.InitParamWithFile(param_path);

    if (!arg_debug_mode) {
        return;
    }
    std::cout << "Weight mat is \n" << bn.GetWeightMatrix().ConstMap();
    std::cout << "Bias mat is \n" << bn.GetBiasMatrix().ConstMap();
    std::cout << "Mean mat is \n" << bn.GetMeanMatrix().ConstMap();
    std::cout << "Std mat is \n" << bn.GetStdMatrix().ConstMap();
}

// Creates an int8 and a float image batch of the given dimensions, fills both
// with FillValueWithIndex() and, in debug mode, prints each batch followed by
// the result of its Flatten() call.
void TestFlatten(int img_row, int img_col, int channel, int img_num) {
    auto int_images = PackImageStorage<int8_t, int32_t, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    int_images.FillValueWithIndex();

    auto fp_images = PackImageStorage<float, float, kRowOrder>(img_row, img_col, channel, img_num, kDefaultActivationBit);
    fp_images.FillValueWithIndex();

    // Flatten() is only exercised when debug output is requested.
    if (!arg_debug_mode) {
        return;
    }
    std::cout << "image int original result:\n" << int_images;
    std::cout << "image int flatten result:\n" << int_images.Flatten();
    std::cout << "image fp original result:\n" << fp_images;
    std::cout << "image fp flatten result:\n" << fp_images.Flatten();
}

int main(int argc, char** argv) {
    std::cout.precision(5);

    int rows = 200;
    int depth = 200;
    int cols = 200;

    for (int i = 1; i < argc; i+= 2) {
        if (std::string(argv[i]) == "--debug") {
            arg_debug_mode = true;
            i -= 1;
        }
        if (std::string(argv[i]) == "--rows") {
            rows = std::stoi(argv[i+1]);
        }
        if (std::string(argv[i]) == "--cols") {
            cols = std::stoi(argv[i+1]);
        }
        if (std::string(argv[i]) == "--depth") {
            depth = std::stoi(argv[i+1]);
        }
    }


    TestMatrixMultiplication(rows, depth, cols); 

    const int img_row = 2;
    const int img_col = 2;
    const int img_num = 2;
    const int channel = 2;
    // TestImgToCol(img_row, img_col, channel, img_num);
    // TestConvForward(img_row, img_col, channel, img_num);
    // TestReluForward(img_row, img_col, channel, img_num);
    // TestBNForward(img_row, img_col, channel, img_num);
    // TestMaxPoolingForward(img_row, img_col, channel, img_num);
    // TestReadMatrix();
    // TestReadConvWeight();
    // TestReadBNParam();
    // TestFlatten(img_row, img_col, img_num, channel);
    
        
    
    return 0;
}
