#ifndef MATRIX_STORAGE_H_
#define MATRIX_STORAGE_H_

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include "../gemmlowp/public/gemmlowp.h"

/// Selects which element buffer of a PackMatrixWithStorage a Map / ConstMap /
/// MapPack / ConstMapPack call views. The three non-NORMAL buffers are only
/// populated after PackMatrixWithStorage::PrepareAuxMatrixMap() has run.
enum class MapOption {
    NORMAL = 0,   ///< the matrix values themselves
    B1REP = 1,    ///< activation mask where bit 0 of the value is set, else 0
    B2REP = 2,    ///< activation mask where bit 1 of the value is set, else 0
    B2FIRST = 3   ///< 1 where bit 1 of the value is set, else 0
};


template <typename tScalar, typename tScalarPack, gemmlowp::MapOrder tOrder>
class PackMatrixWithStorage {
 public:
    PackMatrixWithStorage() = default;
    PackMatrixWithStorage(int rows, int cols, int activation_bit)
            : PackMatrixWithStorage(rows, cols, activation_bit, 0) {}
    PackMatrixWithStorage(int rows, int cols, int activation_bit, tScalar default_value)
            : storage(rows * cols, default_value), matrix_map(storage.data(), rows, cols), activation_bit(activation_bit) {}
    
    PackMatrixWithStorage(std::string file_path, int activation_bit): activation_bit(activation_bit) {
        // Currently only support reading to row major matrix
        bool is_integer = std::numeric_limits<tScalar>::is_integer;
        assert(tOrder == gemmlowp::MapOrder::RowMajor);
        std::ifstream fs;
        if (is_integer) {
            fs.open(file_path + ".int");
        }
        else {
            fs.open(file_path + ".float");
        }
        assert(fs.is_open());
        
        int row_count = 0, col_count = 0;
        for (std::string line; std::getline(fs, line); row_count++) {
            col_count = 0;
            std::stringstream ls(line);
            for (std::string val_str; std::getline(ls, val_str, ','); col_count++) {
                storage.push_back(static_cast<tScalar>(std::stof(val_str)));
            }
            if (col_count % op_pack_ratio != 0) {
                int diff = op_pack_ratio - col_count % op_pack_ratio; 
                for (int i = 0; i < diff; i++) {
                    storage.push_back(static_cast<tScalar>(0));
                }
                col_count += diff;
            }
        }
        matrix_map = gemmlowp::MatrixMap<tScalar, tOrder>(storage.data(), row_count, col_count); 

    }

    void MakeRandom() {
        std::random_device rd;
        static std::mt19937 random_engine(rd());
        std::uniform_real_distribution<float> distribution(-1, 1);
        for (auto& x : storage) {
            x = static_cast<tScalar>(distribution(random_engine));
        }
    }

    void MakeRandomLarge() {
        static std::mt19937 random_engine;
        std::uniform_real_distribution<float> distribution(-10, 10);
        for (auto& x : storage) {
            x = static_cast<tScalar>(distribution(random_engine));
        }
    }

    void MakeRandomPositive() {
        static std::mt19937 random_engine;
        std::uniform_real_distribution<float> distribution(0.0001, 1);
        for (auto& x : storage) {
            x = static_cast<tScalar>(distribution(random_engine));
        }
    }

    void MakeInt8Random() {
        MakeDiscreteIntRandom(-(2 << (activation_bit - 3) ) + 1, (2 << (activation_bit - 3)) - 1, false);
    }
    void MakeUint8Random() {
        MakeDiscreteIntRandom(0, 255, false);
    }
    void MakeTrinaryRandom() {
        MakeDiscreteIntRandom(-1, 1, false);
    }
    void MakeBinaryRandom() {
        MakeBinaryRandom(0, 1, false);
    }

    void FillValueWithIndex() {
        int count = 0;
        for (auto& x: storage) {
            x = count;
            count++;
        }
    }

    void ApplyActivationMask() {
        tScalar activation_bit_mask = GetActivationBitMask();
        for (auto x: storage) {
            x = x & activation_bit_mask;
        }
    }
    
    void PrepareAuxMatrixMap() {
        // Only do this on lhs matrix
        assert(matrix_map.kOrder == gemmlowp::MapOrder::RowMajor);

        tScalar activation_bit_mask = GetActivationBitMask();
        int rows = matrix_map.rows();
        int cols = matrix_map.cols();

        storage_b1_rep.resize(storage.size());
        storage_b2_rep.resize(storage.size());
        storage_b2_first.resize(storage.size());
        storage_b2_first_sum.resize(rows); 
        
        tScalar cur_b2_first_sum = 0;
        for (int i = 0; i < storage.size(); i++) {

            if (i%cols == 0) {
                int indx = i/cols - 1;
                if (indx >= 0) {
                    storage_b2_first_sum[indx] = cur_b2_first_sum;
                }                
                cur_b2_first_sum = 0;
            }

            bool b1_is_1 = ((storage[i] & 1) > 0); 
            bool b2_is_1 = ((storage[i] & 2) > 0); 
            tScalar first_bit_2 = 1;
            tScalar all_0 = 0;
            tScalar all_1 = (~all_0) & activation_bit_mask;

            storage_b1_rep[i] = b1_is_1 ? all_1 : all_0;
            storage_b2_rep[i] = b2_is_1 ? all_1 : all_0;
            storage_b2_first[i] = b2_is_1 ? first_bit_2 : all_0;
            
            cur_b2_first_sum += storage_b2_first[i];
            
        }
        storage_b2_first_sum[storage_b2_first_sum.size()-1] = cur_b2_first_sum;
        
    }

    // The following offers differnt way to get MatrixMap 
    gemmlowp::MatrixMap<const tScalar, tOrder> ConstMap() const {
        return ConstMap(MapOption::NORMAL);
    }
    gemmlowp::MatrixMap<tScalar, tOrder> Map() {
        return Map(MapOption::NORMAL);
    }

    gemmlowp::MatrixMap<const tScalar, tOrder> ConstMap(MapOption opt) const {
        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage_b1_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage_b2_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage_b2_first.data(), matrix_map.rows(), matrix_map.cols());
            default:
                assert(false);
        }
    }

    gemmlowp::MatrixMap<tScalar, tOrder> Map(MapOption opt) {
        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage_b1_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage_b2_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage_b2_first.data(), matrix_map.rows(), matrix_map.cols());
            default:
                assert(false);
        }
    }
    
    
    gemmlowp::MatrixMap<const tScalarPack, tOrder> ConstMapPack() const {
        return ConstMapPack(MapOption::NORMAL);
    }


    gemmlowp::MatrixMap<const tScalarPack, tOrder> ConstMapPack(MapOption opt) const {
        int rows = matrix_map.rows();
        int cols = matrix_map.cols();
        if (matrix_map.kOrder == gemmlowp::MapOrder::RowMajor)
            cols = cols / op_pack_ratio; 
        if (matrix_map.kOrder == gemmlowp::MapOrder::ColMajor)
            rows = rows / op_pack_ratio; 

        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage.data()), rows, cols);
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage_b1_rep.data()), rows, cols);
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage_b2_rep.data()), rows, cols);
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage_b2_first.data()), rows, cols);
            default:
                assert(true);
        }
    }

    gemmlowp::MatrixMap<tScalarPack, tOrder> MapPack() {
        return MapPack(MapOption::NORMAL);
    }

    gemmlowp::MatrixMap<tScalarPack, tOrder> MapPack(MapOption opt) {
        int rows = matrix_map.rows();
        int cols = matrix_map.cols();
        if (matrix_map.kOrder == gemmlowp::MapOrder::RowMajor)
            cols = cols / op_pack_ratio; 
        if (matrix_map.kOrder == gemmlowp::MapOrder::ColMajor)
            rows = rows / op_pack_ratio; 

        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage.data()), rows, cols);
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage_b1_rep.data()), rows, cols);
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage_b2_rep.data()), rows, cols);
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage_b2_first.data()), rows, cols);
            default:
                assert(true);
        }
    }

    

    const std::vector<tScalar>& Storage() const { return storage; }
    std::vector<tScalar>& Storage() { return storage; }
    int GetActivationBit() const {return activation_bit; } 
    tScalar GetActivationBitMask() {
        tScalar activation_bit_mask = 0;
        tScalar one = 1;
        for (int i = 0; i < activation_bit; i++) 
            activation_bit_mask += one << i;
        return activation_bit_mask;
    }
    tScalarPack GetActivationBitMaskPack() {
        tScalarPack activation_bit_mask = static_cast<tScalar>(GetActivationBitMask());
        tScalarPack activation_bit_mask_pack = 0;
        for (int i = 0; i < op_pack_ratio; i++) {
            activation_bit_mask_pack += (activation_bit_mask << (i*8));
        } 
        return activation_bit_mask_pack;
    }
    const std::vector<tScalar>& GetB2FirstSum() const {return storage_b2_first_sum; }
    int GetOpPackRatio() const {return op_pack_ratio; }

 protected:
    void MakeDiscreteIntRandom(int min, int max, bool mask_by_activation_bit) {
        tScalar activation_bit_mask  = GetActivationBitMask();

        std::vector<int> prob_mass(max - min + 1, 1);
        std::random_device rd;
        std::mt19937 gen(rd());
        std::discrete_distribution<> distribution(prob_mass.begin(), prob_mass.end());
        for (auto& x: storage) {
            auto tmp = static_cast<tScalar>(distribution(gen) + min);
            if (mask_by_activation_bit)
                x = tmp & activation_bit_mask;
            else 
                x = tmp;
        }
    } 


 private:
    int op_pack_ratio = sizeof(tScalarPack)/sizeof(tScalar), activation_bit;
    std::vector<tScalar> storage;
    std::vector<tScalar> storage_b1_rep, storage_b2_rep, storage_b2_first;
    std::vector<tScalar> storage_b2_first_sum;
    gemmlowp::MatrixMap<tScalar, tOrder> matrix_map;
};



/// Stores a batch of `img_num` multi-channel images as a
/// ch_num x (img_row * img_col * img_num) matrix. Row c holds channel c of every
/// image. Image n occupies columns [n * img_row * img_col, (n + 1) * img_row * img_col).
/// Img2Col reads pixel (r, q) of an image at offset r * img_col + q within that span.
template <typename tScalar, typename tScalarPack, gemmlowp::MapOrder tOrder>
class PackImageStorage: public PackMatrixWithStorage<tScalar, tScalarPack, tOrder> {
  public:
    using BaseMatrix = PackMatrixWithStorage<tScalar, tScalarPack, tOrder>;
    using BaseColMatrix = PackMatrixWithStorage<tScalar, tScalarPack, gemmlowp::MapOrder::ColMajor>;

    PackImageStorage() = default;
    /// Creates a zero-filled image batch.
    PackImageStorage(int img_row, int img_col, int ch_num, int img_num, int activation_bit)
        : PackImageStorage(img_row, img_col, ch_num, img_num, activation_bit, 0) {}
    /// Creates an image batch filled with `default_value`.
    PackImageStorage(int img_row, int img_col, int ch_num, int img_num, int activation_bit, tScalar default_value)
            : BaseMatrix(ch_num, img_row * img_col * img_num, activation_bit, default_value),
              img_row(img_row), img_col(img_col), ch_num(ch_num), img_num(img_num) {}

    /// Rebuilds and returns a column-major (ch_num * img_row * img_col) x img_num
    /// matrix. Column n is image n, and row c * img_row * img_col + k holds
    /// element k of channel c. The result is cached in this object and
    /// overwritten by the next call.
    const BaseColMatrix& Flatten() {
        const int img_stride = img_row * img_col;

        // The activation width lives in the base class; this class keeps no copy.
        flatten_mat = BaseColMatrix(ch_num * img_stride, img_num, this->GetActivationBit());
        auto f_mat_map = flatten_mat.Map();
        const auto cur_mat_map = this->ConstMap();
        for (int ch = 0; ch < cur_mat_map.rows(); ++ch) {
            for (int j = 0; j < cur_mat_map.cols(); ++j) {
                const int image_idx = j / img_stride;
                const int feature_idx = j % img_stride;
                f_mat_map(ch * img_stride + feature_idx, image_idx) = cur_mat_map(ch, j);
            }
        }
        return flatten_mat;
    }

    /// Writes the im2col expansion of this image batch into `output_matrix`
    /// (column-major only).
    ///
    /// The output must be R x (img_num * patch_num), where R is
    /// ch_num * kernel_size^2 rounded up to a multiple of GetOpPackRatio().
    /// Column n * patch_num + p holds patch p of image n. Row
    /// c * kernel_size^2 + k holds kernel element k (row-major within the kernel)
    /// of channel c. Pixels in the zero-padding border read as 0. The alignment
    /// rows past ch_num * kernel_size^2 are set to 0.
    void Img2Col(int padding, int stride, int kernel_size, BaseColMatrix* output_matrix) const {
        assert(output_matrix != nullptr);
        assert(stride > 0 && kernel_size > 0);
        auto base_matrix_map = output_matrix->Map();
        const auto img_matrix_map = this->ConstMap();

        const int ch_stride = kernel_size * kernel_size;  // output rows per channel
        const int img_stride = img_row * img_col;          // input columns per image

        // Patch origins sweep the zero-padded image [-padding, img + padding).
        const int padded_rows = img_row + 2 * padding;
        const int padded_cols = img_col + 2 * padding;
        assert(kernel_size <= padded_rows && kernel_size <= padded_cols);
        const int row_patch_num = (padded_rows - kernel_size) / stride + 1;
        const int col_patch_num = (padded_cols - kernel_size) / stride + 1;
        const int patch_num = row_patch_num * col_patch_num;

        // The output row count is rounded up to a multiple of the pack ratio.
        const int pack_ratio = this->GetOpPackRatio();
        const int data_rows = ch_stride * ch_num;
        const int aligned_rows = (data_rows + pack_ratio - 1) / pack_ratio * pack_ratio;
        assert(base_matrix_map.rows() == aligned_rows);
        assert(base_matrix_map.cols() == img_num * patch_num);

        for (int c = 0; c < ch_num; c++) {
            for (int n = 0; n < img_num; n++) {
                const int img_col_offset = n * img_stride;
                const int out_row_base = c * ch_stride;
                const int out_col_base = n * patch_num;
                for (int p = 0; p < patch_num; p++) {
                    // Top-left corner of patch p in padded-image coordinates.
                    const int patch_top = -padding + (p / col_patch_num) * stride;
                    const int patch_left = -padding + (p % col_patch_num) * stride;
                    for (int k = 0; k < ch_stride; k++) {
                        const int r = patch_top + k / kernel_size;
                        const int q = patch_left + k % kernel_size;
                        tScalar actual_value = 0;
                        if (IndexInImage(r, q)) {
                            // Pixel (r, q) is row-major inside the image: stride is img_col.
                            actual_value = img_matrix_map(c, img_col_offset + r * img_col + q);
                        }
                        base_matrix_map(out_row_base + k, out_col_base + p) = actual_value;
                    }
                }
            }
        }

        // Clear the alignment rows so a reused output matrix carries no stale data.
        for (int col = 0; col < base_matrix_map.cols(); col++) {
            for (int row = data_rows; row < aligned_rows; row++) {
                base_matrix_map(row, col) = 0;
            }
        }
    }

    /// True if (r, c) lies inside the unpadded img_row x img_col image.
    bool IndexInImage(int r, int c) const { return r >= 0 && r < img_row && c >= 0 && c < img_col; }

    const int GetImgRows() const { return img_row; }
    const int GetImgCols() const { return img_col; }
    const int GetChNum() const { return ch_num; }
    const int GetImgNum() const { return img_num; }

  private:
    int img_row = 0;
    int img_col = 0;
    int ch_num = 0;
    int img_num = 0;
    BaseColMatrix flatten_mat;
};

#endif
