/* Command line to build and run on x86:

c++ doc/quantization_example.cc -I . --std=c++11 -msse4.1 -lpthread \
    -o /tmp/quantization_example && \
/tmp/quantization_example

*/

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include "../public/gemmlowp.h"
//#include "matrix_storage.h"
#include "matrix_multiplication.h"

// We will handle both float and quantized matrices, which we will
// represent as gemmlowp::MatrixMap.
// We will need to be able to print them.

// Output a matrix to a std::ostream

constexpr auto kDefaultRows = 200;
constexpr auto kDefaultCols = 200;
constexpr auto kDefaultDepth = 200;
const auto kRowOrder = gemmlowp::MapOrder::RowMajor;
const auto kColOrder = gemmlowp::MapOrder::ColMajor;


// Multiplies a small int8 LHS by an int8 RHS twice -- once with
// NormalMatrixMultiplication (treated as the reference result) and once with
// PackMatrixMultiplicationMask -- then prints the CPU time of each and the
// accumulated absolute difference between the two products.
//
// Command line:
//   --debug   additionally print the operand matrices and both products.
// Any other argument is skipped together with the argument that follows it
// (i.e. it is treated as an "option value" pair and ignored).
//
// Always returns 0, regardless of the measured difference.
int main(int argc, char** argv) {
    std::cout.precision(15);

    // Scan the arguments. "--debug" is a stand-alone flag and advances by one;
    // everything else advances by two so that an option's value is never
    // mistaken for a flag.
    bool debug_mode = false;
    int arg_index = 1;
    while (arg_index < argc) {
        if (std::strcmp(argv[arg_index], "--debug") == 0) {
            debug_mode = true;
            arg_index += 1;
        } else {
            arg_index += 2;
        }
    }

    const int rows = 4;
    const int depth = 4;
    const int cols = 4;

    // Initialize the int8 operands.
    // NOTE(review): the meaning of the third constructor argument (7 for the
    // operands, 8 for the results) is defined by PackMatrixWithStorage, which
    // is not visible here -- confirm against its declaration.
    PackMatrixWithStorage<int8_t, int, kRowOrder> int8_trinary_lhs(rows, depth, 7);
    int8_trinary_lhs.MakeTrinaryRandom();
    int8_trinary_lhs.PrepareAuxMatrixMap();
    PackMatrixWithStorage<int8_t, int, kColOrder> int8_rhs(depth, cols, 7);
    int8_rhs.MakeInt8Random();

    // Result holders: one for the reference product, one for the packed product.
    PackMatrixWithStorage<int8_t, int, kRowOrder> reference_int8_result(rows, cols, 8);
    PackMatrixWithStorage<int8_t, int, kRowOrder> pack_int8_result(rows, cols, 8);
    auto reference_int8_result_map = reference_int8_result.Map();

    if (debug_mode) {
        std::cout << "Here is the int8 LHS matrix:\n" << int8_trinary_lhs << std::endl;
        std::cout << "Here is the int8 LHS B1rep matrix:\n"
                  << int8_trinary_lhs.ConstMap(MapOption::B1REP) << std::endl;
        std::cout << "Here is the int8 LHS B2rep matrix:\n"
                  << int8_trinary_lhs.ConstMap(MapOption::B2REP) << std::endl;
        std::cout << "Here is the int8 LHS B2first matrix:\n"
                  << int8_trinary_lhs.ConstMap(MapOption::B2FIRST) << std::endl;
        std::cout << "Here is the int8 RHS matrix:\n" << int8_rhs << std::endl;

        std::cout << "Here is the int32 LHS matrix:\n"
                  << int8_trinary_lhs.ConstMapPack() << std::endl;
        std::cout << "Here is the int32 LHS B1rep matrix:\n"
                  << int8_trinary_lhs.ConstMapPack(MapOption::B1REP) << std::endl;
        std::cout << "Here is the int32 LHS B2rep matrix:\n"
                  << int8_trinary_lhs.ConstMapPack(MapOption::B2REP) << std::endl;
        std::cout << "Here is the int32 LHS B2first matrix:\n"
                  << int8_trinary_lhs.ConstMapPack(MapOption::B2FIRST) << std::endl;
        std::cout << "Here is the int32 RHS matrix:\n"
                  << int8_rhs.ConstMapPack() << std::endl;
    }

    // Reference multiplication, timed on its own.
    const std::clock_t normal_begin = std::clock();
    NormalMatrixMultiplication(int8_trinary_lhs.ConstMap(), int8_rhs.ConstMap(),
                               &reference_int8_result_map);
    const std::clock_t normal_end = std::clock();

    // NOTE(review): the mask is applied to the RHS only after the reference
    // product has been computed, and its cost is excluded from the pack
    // timing below -- confirm both are intended.
    int8_rhs.ApplyActivationMask();

    // Packed multiplication, timed on its own.
    const std::clock_t pack_begin = std::clock();
    PackMatrixMultiplicationMask(int8_trinary_lhs, int8_rhs, &pack_int8_result);
    const std::clock_t pack_end = std::clock();

    if (debug_mode) {
        std::cout << "Here is the int8 product (LHS * RHS) matrix obtained by "
                  << "pack matrix multiplication\n"
                  << pack_int8_result << std::endl;
        std::cout << "Here is the int8 product (LHS * RHS) matrix obtained by "
                  << "ordinary matrix multiplication, i.e. as far as we are "
                  << "concerned, the REFERENCE RESULT:\n"
                  << reference_int8_result << std::endl;
    }

    std::cout << "Normal MM takes "
              << static_cast<double>(normal_end - normal_begin) / CLOCKS_PER_SEC
              << " secs\n";
    std::cout << "Pack MM takes "
              << static_cast<double>(pack_end - pack_begin) / CLOCKS_PER_SEC
              << " secs\n";

    // Accumulate the element-wise absolute difference between both products.
    double diff_sum = 0.0;
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            // std::abs selects the overload matching the operand type instead
            // of depending on which unqualified abs happens to be visible.
            diff_sum += std::abs(reference_int8_result.Map()(i, j) -
                                 pack_int8_result.Map()(i, j));
        }
    }

    std::cout << "Diff sum " << diff_sum << "\n";
    std::cout << "The absolute diff per entry is " << diff_sum / (rows * cols) << "\n";

    return 0;
}
