/* Command line to build and run on x86:

c++ doc/quantization_example.cc -I . --std=c++11 -msse4.1 -lpthread \
    -o /tmp/quantization_example && \
/tmp/quantization_example

*/

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <memory>
#include <vector>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include "../gemmlowp/public/gemmlowp.h"
#include "layer.h"
#include "matrix_storage.h"
#include "matrix_multiplication.h"
#include "utils.h"

// We will handle both float and quantized matrices, which we will
// represent as gemmlowp::MatrixMap.
// We will need to be able to print them.

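// Output a matrix to a std::ostream
// NOTE(review): no matrix-printing helper follows in this file; this comment
// looks stale -- confirm and remove, or restore the helper.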

bool arg_debug_mode = false;

constexpr auto kDefaultRows = 200;
constexpr auto kDefaultCols = 200;
constexpr auto kDefaultDepth = 200;
const auto kRowOrder = gemmlowp::MapOrder::RowMajor;
const auto kColOrder = gemmlowp::MapOrder::ColMajor;
constexpr auto kDefaultActivationBit = 7;


// Builds a six-block VGG-style network, loads its parameters from disk and
// runs a single forward pass, printing the matrix dimensions after each block.
//
//   img_num    - forwarded to PackImageStorage together with the fixed
//                32 x 32 x 3 image shape.
//   model_path - directory prefix; parameters are read from
//                "<model_path>/conv1", "<model_path>/bn1", ...,
//                "<model_path>/conv6", "<model_path>/bn6" and
//                "<model_path>/classifier".
//
// Forward order:
//   conv1, relu(float), bn1, q1,
//   conv2, relu(int), dq2, bn2, q2, maxpool,
//   conv3, relu(int), dq3, bn3, q3,
//   conv4, relu(int), dq4, bn4, q4, maxpool,
//   conv5, relu(int), dq5, bn5, q5,
//   conv6, relu(int), dq6, bn6, maxpool(float),
//   flatten, classifier
void VGG(int img_num, const std::string& model_path) {
    int rows = 32;
    int cols = 32;
    int channels = 3;

    // Input storage, filled by MakeRandomPositive() rather than real images.
    auto input = PackImageStorage<float, float, kRowOrder>(
        rows, cols, channels, img_num, kDefaultActivationBit);
    input.MakeRandomPositive();

    // ---- Layer construction -------------------------------------------
    auto conv_l1 = Conv2DFloatDefault(/*ks=*/3, /*channel=*/3, /*kernel_num=*/128,
                                      /*padding=*/0, /*stride=*/1, kDefaultActivationBit);
    auto norm_l1 = BNDefault(128);
    auto quant_l1 = QuantDefault(1.8, 0);

    auto conv_l2 = Conv2DDefault(/*ks=*/3, /*channel=*/128, /*kernel_num=*/128,
                                 /*padding=*/0, /*stride=*/1, kDefaultActivationBit);
    auto dequant_l2 = DeQuantDefault(1.0, 0);
    auto norm_l2 = BNDefault(128);
    auto quant_l2 = QuantDefault(2.2, 0);

    auto conv_l3 = Conv2DDefault(/*ks=*/3, /*channel=*/128, /*kernel_num=*/256,
                                 /*padding=*/0, /*stride=*/1, kDefaultActivationBit);
    auto dequant_l3 = DeQuantDefault(1.0, 0);
    auto norm_l3 = BNDefault(256);
    auto quant_l3 = QuantDefault(2.3, 0);

    auto conv_l4 = Conv2DDefault(/*ks=*/3, /*channel=*/256, /*kernel_num=*/256,
                                 /*padding=*/0, /*stride=*/1, kDefaultActivationBit);
    auto dequant_l4 = DeQuantDefault(1.0, 0);
    auto norm_l4 = BNDefault(256);
    auto quant_l4 = QuantDefault(2.6, 0);

    auto conv_l5 = Conv2DDefault(/*ks=*/3, /*channel=*/256, /*kernel_num=*/512,
                                 /*padding=*/0, /*stride=*/1, kDefaultActivationBit);
    auto dequant_l5 = DeQuantDefault(1.0, 0);
    auto norm_l5 = BNDefault(512);
    auto quant_l5 = QuantDefault(1.8, 0);

    // The last convolution is the only one built with padding 1.
    auto conv_l6 = Conv2DDefault(/*ks=*/3, /*channel=*/512, /*kernel_num=*/512,
                                 /*padding=*/1, /*stride=*/1, kDefaultActivationBit);
    auto dequant_l6 = DeQuantDefault(1.0, 0);
    auto norm_l6 = BNDefault(512);

    auto classifier = LinearDefault(10, 512);

    // Parameter-free layers, shared by every block that needs them.
    auto int_relu = ReluIntDefault();
    auto float_relu = ReluFloatDefault();
    auto int_pool = MaxPoolingDefault(2);
    auto float_pool = MaxPoolingFloatDefault(2);

    // ---- Parameter loading --------------------------------------------
    conv_l1.InitWeightFromFile(model_path + "/conv1");
    norm_l1.InitParamWithFile(model_path + "/bn1");
    conv_l2.InitWeightFromFile(model_path + "/conv2");
    norm_l2.InitParamWithFile(model_path + "/bn2");
    conv_l3.InitWeightFromFile(model_path + "/conv3");
    norm_l3.InitParamWithFile(model_path + "/bn3");
    conv_l4.InitWeightFromFile(model_path + "/conv4");
    norm_l4.InitParamWithFile(model_path + "/bn4");
    conv_l5.InitWeightFromFile(model_path + "/conv5");
    norm_l5.InitParamWithFile(model_path + "/bn5");
    conv_l6.InitWeightFromFile(model_path + "/conv6");
    norm_l6.InitParamWithFile(model_path + "/bn6");
    classifier.InitWeightFromFile(model_path + "/classifier");

    // ---- PackMatMul preparation (conv2..conv6 only; conv1 is the float
    //      convolution and is not prepared here) --------------------------
    conv_l2.InitPackMatMul();
    conv_l3.InitPackMatMul();
    conv_l4.InitPackMatMul();
    conv_l5.InitPackMatMul();
    conv_l6.InitPackMatMul();

    // ---- Block 1 -------------------------------------------------------
    auto l1_relu = float_relu.Forward(conv_l1.Forward(input));
    auto l1_norm = norm_l1.Forward(l1_relu);
    auto l1_quant = quant_l1.Forward(l1_norm);

    std::cout << "block 1 done!\n";
    std::cout << l1_quant.ConstMap().rows() << "\n"
              << l1_quant.ConstMap().cols() << "\n";

    // ---- Block 2 (ends with max pooling) -------------------------------
    auto l2_conv = conv_l2.Forward(l1_quant);
    auto l2_relu = int_relu.Forward(l2_conv);
    auto l2_dequant = dequant_l2.Forward(l2_relu);
    auto l2_norm = norm_l2.Forward(l2_dequant);
    auto l2_quant = quant_l2.Forward(l2_norm);
    auto l2_pool = int_pool.Forward(l2_quant);

    std::cout << "block 2 done!\n";
    std::cout << l2_pool.ConstMap().rows() << "\n"
              << l2_pool.ConstMap().cols() << "\n";

    // ---- Block 3 -------------------------------------------------------
    auto l3_conv = conv_l3.Forward(l2_pool);
    auto l3_relu = int_relu.Forward(l3_conv);
    auto l3_dequant = dequant_l3.Forward(l3_relu);
    auto l3_norm = norm_l3.Forward(l3_dequant);
    auto l3_quant = quant_l3.Forward(l3_norm);

    std::cout << "block 3 done!\n";
    std::cout << l3_quant.ConstMap().rows() << "\n"
              << l3_quant.ConstMap().cols() << "\n";

    // ---- Block 4 (ends with max pooling) -------------------------------
    auto l4_conv = conv_l4.Forward(l3_quant);
    auto l4_relu = int_relu.Forward(l4_conv);
    auto l4_dequant = dequant_l4.Forward(l4_relu);
    auto l4_norm = norm_l4.Forward(l4_dequant);
    auto l4_quant = quant_l4.Forward(l4_norm);
    auto l4_pool = int_pool.Forward(l4_quant);

    std::cout << "block 4 done!\n";
    std::cout << l4_pool.ConstMap().rows() << "\n"
              << l4_pool.ConstMap().cols() << "\n";

    // ---- Block 5 -------------------------------------------------------
    auto l5_conv = conv_l5.Forward(l4_pool);
    auto l5_relu = int_relu.Forward(l5_conv);
    auto l5_dequant = dequant_l5.Forward(l5_relu);
    auto l5_norm = norm_l5.Forward(l5_dequant);
    auto l5_quant = quant_l5.Forward(l5_norm);

    std::cout << "block 5 done!\n";
    std::cout << l5_quant.ConstMap().rows() << "\n"
              << l5_quant.ConstMap().cols() << "\n";

    // ---- Block 6 (no re-quantization; pooling uses the float layer) -----
    auto l6_conv = conv_l6.Forward(l5_quant);
    auto l6_relu = int_relu.Forward(l6_conv);
    auto l6_dequant = dequant_l6.Forward(l6_relu);
    auto l6_norm = norm_l6.Forward(l6_dequant);
    auto l6_pool = float_pool.Forward(l6_norm);

    std::cout << "block 6 done!\n";

    // Dimensions of the raw convolution output, then of the pooled output.
    std::cout << l6_conv.ConstMap().rows() << "\n"
              << l6_conv.ConstMap().cols() << "\n";

    std::cout << l6_pool.ConstMap().rows() << "\n"
              << l6_pool.ConstMap().cols() << "\n";

    // ---- Classifier ----------------------------------------------------
    // Flatten() is called once per use, exactly as before.
    std::cout << "final flatten dim\n";
    std::cout << l6_pool.Flatten().ConstMap().rows() << "\n";
    std::cout << l6_pool.Flatten().ConstMap().cols() << "\n";
    auto logits = classifier.Forward(l6_pool.Flatten());

    std::cout << "clf done!\n";
}

// Program entry point: parses the command line and runs one VGG forward pass
// on a single image.
//
// Recognized arguments:
//   --debug              sets the global arg_debug_mode flag.
//   --model_path <dir>   directory prefix handed to VGG() for loading layer
//                        parameters.
// Any other argument is ignored.
//
// Returns 0 on success, 1 if "--model_path" is given without a value.
int main(int argc, char** argv) {
    std::cout.precision(5);
    std::string model_path;
    SetLayerGlobalActivationBit(kDefaultActivationBit);

    // Walk the arguments one at a time; an option that takes a value consumes
    // the following argument itself. This keeps parsing aligned regardless of
    // where flags without a value (such as --debug) appear.
    for (int i = 1; i < argc; ++i) {
        const std::string arg(argv[i]);
        if (arg == "--debug") {
            arg_debug_mode = true;
        } else if (arg == "--model_path") {
            // argv[argc] is a null pointer; building a std::string from it
            // would be undefined behaviour, so reject a missing value.
            if (i + 1 >= argc) {
                std::cerr << "error: --model_path requires a value\n";
                return 1;
            }
            ++i;
            model_path = std::string(argv[i]);
        }
    }

    const int img_num = 1;
    VGG(img_num, model_path);
    return 0;
}
