#ifndef LAYER_H_
#define LAYER_H_

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <vector>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include "../gemmlowp/public/gemmlowp.h"
#include "matrix_storage.h"
#include "matrix_multiplication.h"
#include "utils.h"

// Process-wide default activation bit width, used by every layer constructor in
// this header that does not receive an explicit activation_bit argument.
//
// The value lives in a function-local static of an inline function, so all
// translation units that include this header share ONE object. A plain variable
// in an anonymous namespace would give each translation unit its own copy. Then
// SetLayerGlobalActivationBit() called in one .cpp file would silently not affect
// layers constructed in another.
inline int& LayerGlobalActivationBitStorage() {
    static int bit = 7;
    return bit;
}

namespace {
    // Per-translation-unit alias kept so existing code can keep reading
    // `global_activation_bit`. Every alias refers to the single shared object above.
    int& global_activation_bit = LayerGlobalActivationBitStorage();
}

// Sets the default activation bit width. Layer constructors copy the current
// value into their own member, so layers that already exist are not affected.
// `inline` is required because this definition lives in a header.
inline void SetLayerGlobalActivationBit(int bit) { LayerGlobalActivationBitStorage() = bit; }

// Abstract interface for a layer that maps one PackImageStorage to another.
//
// Template parameters:
//   tScalarIn / tScalarPackIn   - element type and (presumably) the packing word type of the input image.
//   tScalarOut / tScalarPackOut - element type and (presumably) the packing word type of the output image.
//   tOrder                      - storage order shared by input and output images.
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class Layer {
  public:
    Layer() = default;
    // Layers are meant to be used through this base class, so destruction through
    // a Layer pointer must dispatch to the derived destructor.
    virtual ~Layer() = default;
    Layer(const Layer&) = default;
    Layer& operator=(const Layer&) = default;
    Layer(Layer&&) = default;
    Layer& operator=(Layer&&) = default;

    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;

    // Runs the layer on `input`. The implementations in this header store the
    // result inside the layer and return a reference to it.
    virtual const ImgMatrixOut& Forward(const ImgMatrixIn& input) = 0;
    // Returns the result produced by the most recent Forward() call.
    virtual const ImgMatrixOut& GetResultImgMatrix() = 0;
};

// Abstract interface for layers that operate on plain PackMatrixWithStorage
// matrices instead of PackImageStorage images (e.g. fully connected layers).
// The input is always column-major; the output uses tOrder.
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class BaseLayer {
  public:
    BaseLayer() = default;
    // Polymorphic base: the destructor must be virtual so deleting through a
    // BaseLayer pointer destroys the derived object correctly.
    virtual ~BaseLayer() = default;
    BaseLayer(const BaseLayer&) = default;
    BaseLayer& operator=(const BaseLayer&) = default;
    BaseLayer(BaseLayer&&) = default;
    BaseLayer& operator=(BaseLayer&&) = default;

    using BaseMatrixIn = PackMatrixWithStorage<tScalarIn, tScalarPackIn, gemmlowp::MapOrder::ColMajor>;
    using BaseMatrixOut = PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder>;

    // Runs the layer on `input` and returns a reference to the stored result.
    virtual const BaseMatrixOut& Forward(const BaseMatrixIn& input) = 0;
    // Returns the result produced by the most recent Forward() call.
    virtual const BaseMatrixOut& GetResultMatrix() = 0;
};




// 2-D convolution layer implemented as im2col followed by a matrix
// multiplication (MatMulForward) of the weight matrix with the im2col matrix.
//
// Weight matrix shape: kernel_num x (kernel_size * kernel_size * ch_num). The
// column count is rounded up to a multiple of the pack ratio (presumably so
// that packed storage always holds whole words). The im2col row count is
// padded the same way in Forward().
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class Conv2D : public Layer<tScalarIn, tScalarPackIn, tScalarOut, tScalarPackOut, tOrder> {
  public:
    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;

    // Convenience constructor: 3x3 kernels over 3 input channels, padding 1, stride 2.
    Conv2D(int kernel_num, int activation_bit): Conv2D(3, 3, kernel_num, 1, 2, activation_bit) {}

    // Same as the full constructor, but uses the current global activation bit
    // width (see SetLayerGlobalActivationBit).
    Conv2D(int kernel_size, int ch_num, int kernel_num, int padding, int stride):
        Conv2D(kernel_size, ch_num, kernel_num, padding, stride, global_activation_bit) {}

    // Allocates the weight matrix immediately, because its shape is fully
    // determined by the layer parameters.
    Conv2D(int kernel_size, int ch_num, int kernel_num, int padding, int stride, int activation_bit)
        : kernel_size(kernel_size), ch_num(ch_num), kernel_num(kernel_num), padding(padding), stride(stride),
          activation_bit(activation_bit) {
        const int pack_ratio = sizeof(tScalarPackOut) / sizeof(tScalarOut);
        const int col_num = RoundUpToMultiple(kernel_size * kernel_size * ch_num, pack_ratio);
        weight_mat = PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder>(kernel_num, col_num, activation_bit);
    }

    // Fills the weight matrix via PackMatrixWithStorage::MakeTrinaryRandom().
    void InitRandomWeight() {
        weight_mat.MakeTrinaryRandom();
    }
    void InitWeightFromFile(std::string path);  // not reached: see overload below
    void InitWeightWithMatrix(PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder> weight);

    // Loads the weights from a directory laid out as
    //   <path>/
    //       weight.int (or weight.float)
    // The stored matrix must already have the padded shape the constructor
    // allocates: kernel_num rows and ch_num * kernel_size^2 columns, rounded up
    // to the pack ratio.
    void InitWeightFromFile(const std::string& path) {
        weight_mat = PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder>(path + "/weight", activation_bit);
        const int target_col = RoundUpToMultiple(ch_num * kernel_size * kernel_size, weight_mat.GetOpPackRatio());
        assert(weight_mat.ConstMap().rows() == kernel_num);
        assert(weight_mat.ConstMap().cols() == target_col);
        (void)target_col;  // only read by assert(); silences NDEBUG unused warnings
    }

    // Prepares the auxiliary maps needed by the packed matrix multiplication.
    // Calling this when tScalarIn or tScalarOut is float does not compile, which is intended.
    void InitPackMatMul() {
        weight_mat.PrepareAuxMatrixMap();
    }

    // Convolves `input` and returns the output image (kernel_num channels,
    // one output pixel per kernel placement), which is stored in the layer.
    const ImgMatrixOut& Forward(const ImgMatrixIn& input) override {
        const int img_row = input.GetImgRows();
        const int img_col = input.GetImgCols();
        const int img_ch_num = input.GetChNum();
        const int img_num = input.GetImgNum();
        assert(img_ch_num == ch_num);

        // im2col matrix: one row per (kernel position, channel) pair, padded to
        // the input pack ratio; one column per output pixel over all images.
        const int im2col_rows = RoundUpToMultiple(kernel_size * kernel_size * img_ch_num, input.GetOpPackRatio());
        const int row_patch_num = (img_row + 2 * padding - kernel_size) / stride + 1;
        const int col_patch_num = (img_col + 2 * padding - kernel_size) / stride + 1;
        const int im2col_cols = row_patch_num * col_patch_num * img_num;

        InitIm2ColMat(im2col_rows, im2col_cols);
        InitOutput(row_patch_num, col_patch_num, kernel_num, img_num);

        // Forward pass as im2col + matmul.
        input.Img2Col(padding, stride, kernel_size, &im2col_res_mat);
        const clock_t mm_start = clock();
        MatMulForward(weight_mat, im2col_res_mat, &result_img_mat);
        std::cout << "mm part of this forward takes " << static_cast<float>(clock() - mm_start)/CLOCKS_PER_SEC << "\n";
        return result_img_mat;
    }
    const ImgMatrixOut& GetResultImgMatrix() override { return result_img_mat; }
    const PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder>& GetWeightMatrix() { return weight_mat; }
    const PackMatrixWithStorage<tScalarOut, tScalarPackOut, gemmlowp::MapOrder::ColMajor>& GetImg2ColMatrix() {
        return im2col_res_mat;
    }

  protected:
    void InitOutput(int img_row, int img_col, int ch_num, int img_num) {
        result_img_mat = ImgMatrixOut(img_row, img_col, ch_num, img_num, activation_bit);
    }
    void InitIm2ColMat(int rows, int cols) {
        im2col_res_mat = PackMatrixWithStorage<tScalarOut, tScalarPackOut, gemmlowp::MapOrder::ColMajor>(rows, cols, activation_bit);
    }

  private:
    // Smallest multiple of `multiple` that is >= `value` (value >= 0, multiple > 0).
    static int RoundUpToMultiple(int value, int multiple) {
        const int remainder = value % multiple;
        return remainder == 0 ? value : value + (multiple - remainder);
    }

    int kernel_size, ch_num, kernel_num, padding, stride, activation_bit;
    PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder> weight_mat;
    PackMatrixWithStorage<tScalarOut, tScalarPackOut, gemmlowp::MapOrder::ColMajor> im2col_res_mat;
    ImgMatrixOut result_img_mat;
};

// Batch normalization with fixed (loaded or random) statistics, applied per
// row of the input image matrix:
//     out(i, j) = weight[i] * (x(i, j) - mean[i]) / std[i] + bias[i]
// Each parameter is a 1 x bn_feature_dim matrix indexed by the input row.
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class BatchNorm: public Layer<tScalarIn, tScalarPackIn, tScalarOut, tScalarPackOut, tOrder> {
  public:
    BatchNorm(int bn_feature_dim) : BatchNorm(bn_feature_dim, global_activation_bit) {}
    // Initializers follow the member declaration order (bias, weight, mean, std).
    BatchNorm(int bn_feature_dim, int activation_bit) :
        bn_feature_dim(bn_feature_dim), activation_bit(activation_bit),
        bias_mat(1, bn_feature_dim, activation_bit),
        weight_mat(1, bn_feature_dim, activation_bit),
        mean_mat(1, bn_feature_dim, activation_bit),
        std_mat(1, bn_feature_dim, activation_bit) {}
    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;
    using ParamMatrix = PackMatrixWithStorage<tScalarOut, tScalarPackOut, tOrder>;

    // Applies weight * (x - mean) / std + bias element-wise, with parameters
    // selected by row, and returns the stored output image.
    const ImgMatrixOut& Forward(const ImgMatrixIn& input) override {
        InitOutput(input.GetImgRows(), input.GetImgCols(), input.GetChNum(), input.GetImgNum());

        const auto& input_mat_map = input.ConstMap();
        const auto& bias_mat_map = bias_mat.ConstMap();
        const auto& weight_mat_map = weight_mat.ConstMap();
        const auto& mean_mat_map = mean_mat.ConstMap();
        const auto& std_mat_map = std_mat.ConstMap();

        auto output_mat_map = result_img_matrix.Map();

        // Every input row needs its own parameter column.
        assert(input_mat_map.rows() <= bn_feature_dim);

        for (int i = 0; i < input_mat_map.rows(); i++) {
            const auto weight = weight_mat_map(0, i);
            const auto mean = mean_mat_map(0, i);
            const auto std_dev = std_mat_map(0, i);
            const auto bias = bias_mat_map(0, i);
            for (int j = 0; j < input_mat_map.cols(); j++) {
                const tScalarOut val = input_mat_map(i, j);
                output_mat_map(i, j) = weight * (val - mean) / std_dev + bias;
            }
        }
        return result_img_matrix;
    }
    const ImgMatrixOut& GetResultImgMatrix() override { return result_img_matrix; }

    // Fills all parameters with random values. std is kept positive so the
    // division in Forward() is well defined. The input argument is unused.
    void InitRandomParam(const ImgMatrixIn& /*input*/) {
        bias_mat.MakeRandom();
        weight_mat.MakeRandom();
        mean_mat.MakeRandom();
        std_mat.MakeRandomPositive();
    }
    void InitParamWithMatrix(const ParamMatrix& beta, const ParamMatrix& gamma,
            const ParamMatrix& mean, const ParamMatrix& std);

    // Loads parameters from a directory laid out as
    //   <path>/
    //       running_mean.<type>
    //       running_std.<type>
    //       weight.<type>
    //       bias.<type>
    // Every file must hold a 1 x bn_feature_dim matrix.
    void InitParamWithFile(const std::string& path) {
        bias_mat = ParamMatrix(path + "/bias", activation_bit);
        weight_mat = ParamMatrix(path + "/weight", activation_bit);
        std_mat = ParamMatrix(path + "/running_std", activation_bit);
        mean_mat = ParamMatrix(path + "/running_mean", activation_bit);

        assert(bias_mat.ConstMap().rows() == 1);
        assert(bias_mat.ConstMap().cols() == bn_feature_dim);
        assert(weight_mat.ConstMap().rows() == 1);
        assert(weight_mat.ConstMap().cols() == bn_feature_dim);
        assert(mean_mat.ConstMap().rows() == 1);
        assert(mean_mat.ConstMap().cols() == bn_feature_dim);
        assert(std_mat.ConstMap().rows() == 1);
        assert(std_mat.ConstMap().cols() == bn_feature_dim);
    }

    const ParamMatrix& GetBiasMatrix() { return bias_mat; }
    const ParamMatrix& GetWeightMatrix() { return weight_mat; }
    const ParamMatrix& GetMeanMatrix() { return mean_mat; }
    const ParamMatrix& GetStdMatrix() { return std_mat; }

  protected:
    void InitOutput(int img_row, int img_col, int ch_num, int img_num) {
        result_img_matrix = ImgMatrixOut(img_row, img_col, ch_num, img_num, activation_bit);
    }

  private:
    // Number of features (parameter entries) the layer normalizes.
    int bn_feature_dim;
    int activation_bit;
    ParamMatrix bias_mat, weight_mat, mean_mat, std_mat;
    ImgMatrixOut result_img_matrix;
};

// Element-wise rectified linear unit: every output element is max(0, input element).
// The output image has the same dimensions as the input and is stored in the layer.
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class Relu: public Layer<tScalarIn, tScalarPackIn, tScalarOut, tScalarPackOut, tOrder>{
  public:
    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;

    // Without an argument, the current global activation bit width is used.
    Relu(): Relu(global_activation_bit) {}
    Relu(int activation_bit): activation_bit(activation_bit) {}

    const ImgMatrixOut& Forward(const ImgMatrixIn& input) override {
        InitOutput(input.GetImgRows(), input.GetImgCols(), input.GetChNum(), input.GetImgNum());
        assert(input.Storage().size() == result_img_matrix.Storage().size());

        const auto& src = input.Storage();
        auto& dst = result_img_matrix.Storage();
        const tScalarIn zero = static_cast<tScalarIn>(0);
        const std::size_t count = src.size();
        for (std::size_t idx = 0; idx < count; ++idx) {
            dst[idx] = std::max(zero, src[idx]);
        }
        return result_img_matrix;
    }

    const ImgMatrixOut& GetResultImgMatrix() override {
        return result_img_matrix;
    }

  protected:
    // (Re)allocates the output image with the given dimensions.
    void InitOutput(int img_row, int img_col, int ch_num, int img_num) {
        result_img_matrix = ImgMatrixOut(img_row, img_col, ch_num, img_num, activation_bit);
    }

  private:
    int activation_bit;
    ImgMatrixOut result_img_matrix;
};

// Non-overlapping max pooling with a square window whose side equals `stride`.
// The output size is floor(in / stride) in each spatial dimension. Input
// rows/columns that do not fill a complete window are ignored (the trailing
// partial window is dropped).
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class MaxPooling: public Layer<tScalarIn, tScalarPackIn, tScalarOut, tScalarPackOut, tOrder>{
  public:
    MaxPooling(int stride): MaxPooling(stride, global_activation_bit) {}
    MaxPooling(int stride, int activation_bit): stride(stride), activation_bit(activation_bit) {}

    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;

    const ImgMatrixOut& Forward(const ImgMatrixIn& input) override {
        assert(stride > 0);
        const int img_row = input.GetImgRows();
        const int img_col = input.GetImgCols();
        const int ch_num = input.GetChNum();
        const int img_num = input.GetImgNum();
        const int img_stride = img_row * img_col;

        // Floor-mode output size.
        const int out_img_row = img_row / stride;
        const int out_img_col = img_col / stride;
        const int out_img_stride = out_img_row * out_img_col;

        InitOutput(out_img_row, out_img_col, ch_num, img_num);

        const auto& input_mat_map = input.ConstMap();
        auto output_mat_map = result_img_matrix.Map();

        // Indexing used here: row = channel; column = image_index * img_stride + i * img_col + j.
        for (int c = 0; c < input_mat_map.rows(); c++) {
            for (int k = 0; k < input_mat_map.cols(); k++) {
                const int img_count = k / img_stride;
                const int cur_img_k = k % img_stride;
                const int cur_out_img_i = (cur_img_k / img_col) / stride;
                const int cur_out_img_j = (cur_img_k % img_col) / stride;
                // Pixels in a trailing partial window have no output cell. Without
                // this check they would be written into the next image's region
                // (or past the end of the output for the last image).
                if (cur_out_img_i >= out_img_row || cur_out_img_j >= out_img_col) {
                    continue;
                }
                const int out_k = img_count * out_img_stride + cur_out_img_i * out_img_col + cur_out_img_j;
                output_mat_map(c, out_k) = std::max(output_mat_map(c, out_k), input_mat_map(c, k));
            }
        }
        return result_img_matrix;
    }
    const ImgMatrixOut& GetResultImgMatrix() override { return result_img_matrix; }

  protected:
    // Allocates the output image, passing the most negative representable value
    // as the last constructor argument (presumably the initial fill value), so
    // any input element wins the first max(). numeric_limits<T>::min() would be
    // wrong for floating-point types: it is the smallest *positive* normal value.
    void InitOutput(int img_row, int img_col, int ch_num, int img_num) {
        result_img_matrix = ImgMatrixOut(img_row, img_col, ch_num, img_num,
                activation_bit, std::numeric_limits<tScalarOut>::lowest());
    }

  private:
    int stride, activation_bit;
    ImgMatrixOut result_img_matrix;
};

// Converts a floating-point image into an integer image. Each element goes
// through Quantize() with this layer's scale, shift and activation bit width.
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class Quantization: public Layer<tScalarIn, tScalarPackIn, tScalarOut, tScalarPackOut, tOrder>{
  public:
    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;

    Quantization(float scale, int shift): Quantization(scale, shift, global_activation_bit) {}
    // Initializers are listed in member declaration order.
    Quantization(float scale, int shift, int activation_bit)
        : scale(scale), activation_bit(activation_bit), shift(shift) {}

    const ImgMatrixOut& Forward(const ImgMatrixIn& input) override {
        InitOutput(input.GetImgRows(), input.GetImgCols(), input.GetChNum(), input.GetImgNum());
        assert(input.Storage().size() == result_img_matrix.Storage().size());
        // Only float -> integer conversion is supported.
        assert(!std::numeric_limits<tScalarIn>::is_integer && std::numeric_limits<tScalarOut>::is_integer);

        const tScalarIn in_scale = static_cast<tScalarIn>(scale);
        const tScalarOut out_shift = static_cast<tScalarOut>(shift);
        const auto& src = input.Storage();
        auto& dst = result_img_matrix.Storage();
        const std::size_t count = src.size();
        for (std::size_t idx = 0; idx < count; ++idx) {
            auto element = src[idx];
            dst[idx] = Quantize(element, in_scale, out_shift, activation_bit);
        }
        return result_img_matrix;
    }

    const ImgMatrixOut& GetResultImgMatrix() override {
        return result_img_matrix;
    }

  protected:
    // (Re)allocates the output image with the given dimensions.
    void InitOutput(int img_row, int img_col, int ch_num, int img_num) {
        result_img_matrix = ImgMatrixOut(img_row, img_col, ch_num, img_num, activation_bit);
    }

  private:
    float scale;
    int activation_bit, shift;
    ImgMatrixOut result_img_matrix;
};

// Converts an integer image back to floating point. Each element goes through
// DeQuantize() with this layer's scale and shift.
template <typename tScalarIn, typename tScalarPackIn, typename tScalarOut, typename tScalarPackOut, gemmlowp::MapOrder tOrder>
class DeQuantization: public Layer<tScalarIn, tScalarPackIn, tScalarOut, tScalarPackOut, tOrder>{
  public:
    using ImgMatrixIn = PackImageStorage<tScalarIn, tScalarPackIn, tOrder>;
    using ImgMatrixOut = PackImageStorage<tScalarOut, tScalarPackOut, tOrder>;

    DeQuantization(float scale, int shift): DeQuantization(scale, shift, global_activation_bit) {}
    // Initializers are listed in member declaration order.
    DeQuantization(float scale, int shift, int activation_bit)
        : scale(scale), activation_bit(activation_bit), shift(shift) {}

    const ImgMatrixOut& Forward(const ImgMatrixIn& input) override {
        InitOutput(input.GetImgRows(), input.GetImgCols(), input.GetChNum(), input.GetImgNum());
        assert(input.Storage().size() == result_img_matrix.Storage().size());
        // Only integer -> float conversion is supported.
        assert(std::numeric_limits<tScalarIn>::is_integer && !std::numeric_limits<tScalarOut>::is_integer);

        const tScalarOut out_scale = static_cast<tScalarOut>(scale);
        const tScalarIn in_shift = static_cast<tScalarIn>(shift);
        const auto& src = input.Storage();
        auto& dst = result_img_matrix.Storage();
        const std::size_t count = src.size();
        for (std::size_t idx = 0; idx < count; ++idx) {
            auto element = src[idx];
            dst[idx] = DeQuantize(element, out_scale, in_shift);
        }
        return result_img_matrix;
    }

    const ImgMatrixOut& GetResultImgMatrix() override {
        return result_img_matrix;
    }

  protected:
    // (Re)allocates the output image with the given dimensions.
    void InitOutput(int img_row, int img_col, int ch_num, int img_num) {
        result_img_matrix = ImgMatrixOut(img_row, img_col, ch_num, img_num, activation_bit);
    }

  private:
    float scale;
    int activation_bit, shift;
    ImgMatrixOut result_img_matrix;
};


// Fully connected layer. Forward() calls MatMulForward(weight, input, &result)
// with an out_dim x in_dim weight matrix and a column-major input. The result
// is allocated as out_dim x input.cols() and stored in the layer.
template <typename tScalar, typename tScalarPack, gemmlowp::MapOrder tOrder>
class Linear: public BaseLayer<tScalar, tScalarPack, tScalar, tScalarPack, tOrder>{
  public:
    using BaseMatrix = PackMatrixWithStorage<tScalar, tScalarPack, tOrder>;
    using BaseColMatrix = PackMatrixWithStorage<tScalar, tScalarPack, gemmlowp::MapOrder::ColMajor>;
    using ImgMatrixIn = PackImageStorage<tScalar, tScalarPack, tOrder>;

    Linear(int out_dim, int in_dim): Linear(out_dim, in_dim, global_activation_bit) {}
    // Initializers are listed in member declaration order. The weight_mat
    // arguments refer to the constructor parameters, not the members.
    Linear(int out_dim, int in_dim, int activation_bit)
        : weight_mat(out_dim, in_dim, activation_bit),
          activation_bit(activation_bit),
          in_dim(in_dim),
          out_dim(out_dim) {}

    const BaseMatrix& Forward(const BaseColMatrix& input) override {
        const int result_cols = input.ConstMap().cols();
        InitOutput(out_dim, result_cols);
        MatMulForward(weight_mat, input, &result_mat);
        return result_mat;
    }

    const BaseMatrix& GetResultMatrix() override {
        return result_mat;
    }

    // Fills the weight matrix via PackMatrixWithStorage::MakeRandom().
    void InitRandomWeight() {
        weight_mat.MakeRandom();
    }

    // Loads the weights from a directory laid out as
    //   <path>/
    //       weight.<dtype>
    // The stored matrix must be out_dim x in_dim.
    void InitWeightFromFile(std::string path) {
        weight_mat = BaseMatrix(path + "/weight", activation_bit);
        assert(weight_mat.ConstMap().rows() == out_dim);
        assert(weight_mat.ConstMap().cols() == in_dim);
    }

  protected:
    // (Re)allocates the result matrix with the given shape.
    void InitOutput(int row, int col) {
        result_mat = BaseMatrix(row, col, activation_bit);
    }

  private:
    BaseMatrix weight_mat, result_mat;
    int activation_bit, in_dim, out_dim;
};

// Ready-made instantiations of the layers above. All use row-major storage.
// Integer variants use int8_t elements with int32_t as the pack type
// (presumably several int8_t values per int32_t word). Float variants use
// float for both.

// int8_t / int32_t layers
using Conv2DDefault = Conv2D<int8_t, int32_t, int8_t, int32_t, gemmlowp::MapOrder::RowMajor>;
using ReluIntDefault = Relu<int8_t, int32_t, int8_t, int32_t, gemmlowp::MapOrder::RowMajor>;
using MaxPoolingDefault = MaxPooling<int8_t, int32_t, int8_t, int32_t, gemmlowp::MapOrder::RowMajor>;

// float layers
using Conv2DFloatDefault = Conv2D<float, float, float, float, gemmlowp::MapOrder::RowMajor>;
using BNDefault = BatchNorm<float, float, float, float, gemmlowp::MapOrder::RowMajor>;
using ReluFloatDefault = Relu<float, float, float, float, gemmlowp::MapOrder::RowMajor>;
using MaxPoolingFloatDefault = MaxPooling<float, float, float, float, gemmlowp::MapOrder::RowMajor>;
using LinearDefault = Linear<float, float, gemmlowp::MapOrder::RowMajor>;

// Conversions between the float and int8_t domains
using QuantDefault = Quantization<float, float, int8_t, int32_t, gemmlowp::MapOrder::RowMajor>;
using DeQuantDefault = DeQuantization<int8_t, int32_t, float, float, gemmlowp::MapOrder::RowMajor>;
#endif
