#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>
#include "../public/gemmlowp.h" 

// Selects which backing buffer of PackMatrixWithStorage a map view exposes:
// the raw elements, or one of the auxiliary planes that
// PrepareAuxMatrixMap() derives from them.
enum class MapOption {
    NORMAL,   // the element storage itself
    B1REP,    // bit 0 of each element replicated across the activation bits
    B2REP,    // bit 1 of each element replicated across the activation bits
    B2FIRST,  // 1 where bit 1 of the element is set, 0 otherwise
};


template <typename tScalar, typename tScalarPack, gemmlowp::MapOrder tOrder>
class PackMatrixWithStorage {
 public:
    PackMatrixWithStorage(int rows, int cols, int activation_bit)
            : storage(rows * cols), matrix_map(storage.data(), rows, cols), activation_bit(activation_bit) {}
    void MakeRandom() {
        static std::mt19937 random_engine;
        std::uniform_real_distribution<float> distribution(-1, 1);
        for (auto& x : storage) {
            x = static_cast<tScalar>(distribution(random_engine));
        }
    }
    
    void MakeInt8Random() {
        MakeDiscreteIntRandom(-(2 << (activation_bit - 3)), (2 << (activation_bit - 3)) - 1, false);
    }
    void MakeUint8Random() {
        MakeDiscreteIntRandom(0, 255, false);
    }
    void MakeTrinaryRandom() {
        MakeDiscreteIntRandom(-1, 1, false);
    }
    void MakeBinaryRandom() {
        MakeBinaryRandom(0, 1, false);
    }
    void ApplyActivationMask() {
        tScalar activation_bit_mask = GetActivationBitMask();
        for (auto x: storage) {
            x = x & activation_bit_mask;
        }
    }
    
    void PrepareAuxMatrixMap() {
        // Only do this on lhs matrix
        assert(matrix_map.kOrder == gemmlowp::MapOrder::RowMajor);

        tScalar activation_bit_mask = GetActivationBitMask();
        int rows = matrix_map.rows();
        int cols = matrix_map.cols();

        storage_b1_rep.resize(storage.size());
        storage_b2_rep.resize(storage.size());
        storage_b2_first.resize(storage.size());
        storage_b2_first_sum.resize(rows); 
        
        tScalar cur_b2_first_sum = 0;
        for (int i = 0; i < storage.size(); i++) {
            bool b1_is_1 = ((storage[i] & 1) > 0); 
            bool b2_is_1 = ((storage[i] & 2) > 0); 
            tScalar first_bit_2 = 1;
            tScalar all_0 = 0;
            tScalar all_1 = (~all_0) & activation_bit_mask;

            storage_b1_rep[i] = b1_is_1 ? all_1 : all_0;
            storage_b2_rep[i] = b2_is_1 ? all_1 : all_0;
            storage_b2_first[i] = b2_is_1 ? first_bit_2 : all_0;
            
            cur_b2_first_sum += storage_b2_first[i];
            if (i%cols == 0) {
                int indx = i/cols - 1;
                if (indx < 0) continue;
                storage_b2_first_sum[indx] = cur_b2_first_sum;
                cur_b2_first_sum = 0;
            }
            
        }
        
    }

    // The following offers differnt way to get MatrixMap 
    gemmlowp::MatrixMap<const tScalar, tOrder> ConstMap() const {
        return ConstMap(MapOption::NORMAL);
    }
    gemmlowp::MatrixMap<tScalar, tOrder> Map() {
        return Map(MapOption::NORMAL);
    }

    gemmlowp::MatrixMap<const tScalar, tOrder> ConstMap(MapOption opt) const {
        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage_b1_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage_b2_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<const tScalar, tOrder>(
                    storage_b2_first.data(), matrix_map.rows(), matrix_map.cols());
            default:
                assert(false);
        }
    }

    gemmlowp::MatrixMap<tScalar, tOrder> Map(MapOption opt) {
        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage_b1_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage_b2_rep.data(), matrix_map.rows(), matrix_map.cols());
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<tScalar, tOrder>(
                    storage_b2_first.data(), matrix_map.rows(), matrix_map.cols());
            default:
                assert(false);
        }
    }
    
    
    gemmlowp::MatrixMap<const tScalarPack, tOrder> ConstMapPack() const {
        return ConstMapPack(MapOption::NORMAL);
    }


    gemmlowp::MatrixMap<const tScalarPack, tOrder> ConstMapPack(MapOption opt) const {
        int rows = matrix_map.rows();
        int cols = matrix_map.cols();
        if (matrix_map.kOrder == gemmlowp::MapOrder::RowMajor)
            cols = cols / op_pack_ratio; 
        if (matrix_map.kOrder == gemmlowp::MapOrder::ColMajor)
            rows = rows / op_pack_ratio; 

        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage.data()), rows, cols);
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage_b1_rep.data()), rows, cols);
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage_b2_rep.data()), rows, cols);
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<const tScalarPack, tOrder>(
                    reinterpret_cast<const tScalarPack*>(storage_b2_first.data()), rows, cols);
            default:
                assert(true);
        }
    }

    gemmlowp::MatrixMap<tScalarPack, tOrder> MapPack() {
        return MapPack(MapOption::NORMAL);
    }

    gemmlowp::MatrixMap<tScalarPack, tOrder> MapPack(MapOption opt) {
        int rows = matrix_map.rows();
        int cols = matrix_map.cols();
        if (matrix_map.kOrder == gemmlowp::MapOrder::RowMajor)
            cols = cols / op_pack_ratio; 
        if (matrix_map.kOrder == gemmlowp::MapOrder::ColMajor)
            rows = rows / op_pack_ratio; 

        switch(opt) {
            case MapOption::NORMAL:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage.data()), rows, cols);
            case MapOption::B1REP:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage_b1_rep.data()), rows, cols);
            case MapOption::B2REP:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage_b2_rep.data()), rows, cols);
            case MapOption::B2FIRST:
                return gemmlowp::MatrixMap<tScalarPack, tOrder>(
                    reinterpret_cast<tScalarPack*>(storage_b2_first.data()), rows, cols);
            default:
                assert(true);
        }
    }

    

    const std::vector<tScalar>& Storage() const { return storage; }
    std::vector<tScalar>& Storage() { return storage; }
    int GetActivationBit() const {return activation_bit; } 
    tScalar GetActivationBitMask() {
        tScalar activation_bit_mask = 0;
        tScalar one = 1;
        for (int i = 0; i < activation_bit; i++) 
            activation_bit_mask += one << i;
        return activation_bit_mask;
    }
    tScalarPack GetActivationBitMaskPack() {
        tScalarPack activation_bit_mask = static_cast<tScalar>(GetActivationBitMask());
        tScalarPack activation_bit_mask_pack = 0;
        for (int i = 0; i < op_pack_ratio; i++) {
            activation_bit_mask_pack += (activation_bit_mask << (i*8));
        } 
        return activation_bit_mask_pack;
    }
    const std::vector<tScalar>& GetB2FirstSum() const {return storage_b2_first_sum; }
    int GetOpPackRatio() const {return op_pack_ratio; }

 protected:
    void MakeDiscreteIntRandom(int min, int max, bool mask_by_activation_bit) {
        tScalar activation_bit_mask  = GetActivationBitMask();

        std::vector<int> prob_mass(max - min + 1, 1);
        std::random_device rd;
        std::mt19937 gen(rd());
        std::discrete_distribution<> distribution(prob_mass.begin(), prob_mass.end());
        for (auto& x: storage) {
            auto tmp = static_cast<tScalar>(distribution(gen) + min);
            if (mask_by_activation_bit)
                x = tmp & activation_bit_mask;
            else 
                x = tmp;
        }
    } 


 private:
    int op_pack_ratio = sizeof(tScalarPack)/sizeof(tScalar), activation_bit;
    std::vector<tScalar> storage;
    std::vector<tScalar> storage_b1_rep, storage_b2_rep, storage_b2_first;
    std::vector<tScalar> storage_b2_first_sum;
    gemmlowp::MatrixMap<tScalar, tOrder> matrix_map;
};

template <typename tScalar, gemmlowp::MapOrder tOrder>
std::ostream& operator<<(std::ostream& s,
                         const gemmlowp::MatrixMap<tScalar, tOrder>& m) {
    // One text line per row, entries tab-separated. Each entry is widened to
    // int so 8-bit element types print as numbers rather than characters.
    const int row_count = m.rows();
    const int col_count = m.cols();
    for (int r = 0; r < row_count; ++r) {
        for (int c = 0; c < col_count; ++c) {
            const bool needs_separator = (c != 0);
            if (needs_separator) s << '\t';
            s << static_cast<int>(m(r, c));
        }
        s << '\n';
    }
    return s;
}

template <typename tScalar, typename tScalarPack, gemmlowp::MapOrder tOrder>
std::ostream& operator<<(
        std::ostream& s,
        const PackMatrixWithStorage<tScalar, tScalarPack, tOrder>& m) {
    // Print the plain (MapOption::NORMAL) element view.
    const auto element_view = m.ConstMap();
    s << element_view;
    return s;
}

