#ifndef MATRIX_MULTIPLICATION_H_
#define MATRIX_MULTIPLICATION_H_

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include "../gemmlowp/public/gemmlowp.h"
#include "matrix_storage.h"

// Selects which kernel the generic MatMulForward dispatches to.
enum class MMOPTION {
    PACK,      // PackMatrixMultiplicationMask
    NORMAL,    // NormalMatrixMultiplication (initial setting)
    GEMMLOWP   // accepted as a value, but MatMulForward has no kernel for it
};

// Shared storage for the currently selected engine.
//
// This header can be included from several translation units, so the setting
// has to be one object for the whole program. A static data member of a class
// template gives exactly that without requiring C++17 inline variables: every
// translation unit may carry the definition and the linker keeps a single copy.
namespace matmul_internal {
template <typename tUnused = void>
struct EngineSelection {
    static MMOPTION value;
};

template <typename tUnused>
MMOPTION EngineSelection<tUnused>::value = MMOPTION::NORMAL;
}  // namespace matmul_internal

// File-local name kept so code in this header that reads `mm_engine_option`
// continues to compile. It is only a reference: every translation unit's
// reference is bound to the same shared object above.
namespace {
    MMOPTION& mm_engine_option = matmul_internal::EngineSelection<>::value;
}

// Chooses the engine used by subsequent MatMulForward calls.
//
// Declared inline because it is defined in a header; without it, including
// this file from two translation units produces duplicate-symbol link errors.
// No synchronization is performed on the stored option.
inline void SetMatMulEngine(MMOPTION opt) {
    matmul_internal::EngineSelection<>::value = opt;
}


// Straightforward triple-loop matrix product: *result = lhs * rhs.
//
// lhs    : left operand view.
// rhs    : right operand view; its row count must equal lhs.cols().
// result : output view, sized lhs.rows() x rhs.cols(). Each cell is first
//          set to zero and then accumulated in place, one product at a time.
//
// Shape compatibility is checked with assert only, so there is no check in
// builds that define NDEBUG.
template <typename tScalar, gemmlowp::MapOrder tLhsOrder, gemmlowp::MapOrder tRhsOrder,
                    gemmlowp::MapOrder tResultOrder>
void NormalMatrixMultiplication(
        const gemmlowp::MatrixMap<const tScalar, tLhsOrder>& lhs,
        const gemmlowp::MatrixMap<const tScalar, tRhsOrder>& rhs,
        gemmlowp::MatrixMap<tScalar, tResultOrder>* result) {
    assert(lhs.cols() == rhs.rows());
    assert(lhs.rows() == result->rows());
    assert(rhs.cols() == result->cols());

    // The input views are const, so their dimensions can be read once.
    const auto out_rows = lhs.rows();
    const auto out_cols = rhs.cols();
    const auto depth = lhs.cols();
    auto& out = *result;

    for (int row = 0; row < out_rows; ++row) {
        for (int col = 0; col < out_cols; ++col) {
            out(row, col) = 0;
            for (int d = 0; d < depth; ++d) {
                out(row, col) += lhs(row, d) * rhs(d, col);
            }
        }
    }
}

// Packed multiplication kernel that uses all three packed views of lhs
// (B1REP, B2REP and B2FIRST) together with the packed view of rhs.
//
// For every output cell (row, col) the kernel:
//   1. walks the packed depth dimension and, per packed word, computes
//      ((rhs & lhs_b1_rep) ^ lhs_b2_rep) + lhs_b2_first, adding each term
//      into one packed accumulator;
//   2. folds the accumulator by adding it to a running total and shifting it
//      right by sizeof(tScalar) * 8 bits, GetOpPackRatio() times;
//   3. stores the total converted to tScalar.
//
// NOTE(review): the meaning of the B1REP / B2REP / B2FIRST encodings is
// defined by PackMatrixWithStorage and is not visible here — confirm there.
//
// The operands are taken by value, as in the rest of this header. Shape and
// pack-ratio agreement is checked with assert only.
template <typename tScalar, typename tScalrPack, gemmlowp::MapOrder tLhsOrder, 
            gemmlowp::MapOrder tRhsOrder, gemmlowp::MapOrder tResultOrder>
void PackMatrixMultiplicationFull(
        PackMatrixWithStorage<tScalar, tScalrPack, tLhsOrder> lhs, 
        PackMatrixWithStorage<tScalar, tScalrPack, tRhsOrder> rhs,
        PackMatrixWithStorage<tScalar, tScalrPack, tResultOrder>* result
        ) {
    assert(lhs.GetOpPackRatio() == rhs.GetOpPackRatio());

    const auto lhs_b1_rep = lhs.ConstMapPack(MapOption::B1REP);
    const auto lhs_b2_rep = lhs.ConstMapPack(MapOption::B2REP);
    const auto lhs_b2_first = lhs.ConstMapPack(MapOption::B2FIRST);
    const auto rhs_packed = rhs.ConstMapPack();

    auto out = result->Map();

    assert(lhs_b1_rep.cols() == rhs_packed.rows());
    assert(lhs_b1_rep.rows() == out.rows());
    assert(rhs_packed.cols() == out.cols());

    const int lanes_per_word = lhs.GetOpPackRatio();
    const int lane_bits = sizeof(tScalar) * 8;

    const auto out_rows = lhs_b1_rep.rows();
    const auto out_cols = rhs_packed.cols();
    const auto packed_depth = lhs_b1_rep.cols();

    for (int row = 0; row < out_rows; ++row) {
        for (int col = 0; col < out_cols; ++col) {
            // Step 1: accumulate the combined packed words along the depth.
            tScalrPack packed_acc = 0;
            for (int d = 0; d < packed_depth; ++d) {
                tScalrPack term = rhs_packed(d, col);
                term &= lhs_b1_rep(row, d);
                term ^= lhs_b2_rep(row, d);
                term += lhs_b2_first(row, d);
                packed_acc += term;
            }

            // Step 2: add up the accumulator once per lane, shifting the next
            // lane down into the low bits after each addition.
            tScalrPack folded = 0;
            for (int lane = 0; lane < lanes_per_word; ++lane) {
                folded += packed_acc;
                packed_acc = packed_acc >> lane_bits;
            }

            // Step 3: the conversion keeps only what fits in tScalar.
            out(row, col) = static_cast<tScalar>(folded);
        }
    }
}

// Packed multiplication kernel that masks every lane to the activation bit
// width and adds a precomputed per-row correction at the end.
//
// For every output cell (row, col) the kernel:
//   1. walks the packed depth dimension and, per packed word, computes
//      ((rhs & mask_pack) & lhs_b1_rep) ^ lhs_b2_rep, adds it to a packed
//      accumulator and re-applies mask_pack to the accumulator;
//   2. unpacks GetOpPackRatio() lanes of sizeof(tScalar) * 8 bits each: the
//      low lane is masked, shifted left and then right again by
//      (sizeof(tScalar) * 8 - GetActivationBit()) in tScalar, and summed;
//   3. stores the sum converted to tScalar plus lhs.GetB2FirstSum()[row].
//
// NOTE(review): the left/right shift pair presumably sign-extends a lane that
// is GetActivationBit() bits wide when tScalar is signed — confirm against
// PackMatrixWithStorage. The B1REP / B2REP encodings are likewise defined
// outside this file.
//
// The operands are taken by value, as in the rest of this header. Shape,
// pack-ratio and bit-mask agreement is checked with assert only.
template <typename tScalar, typename tScalrPack, gemmlowp::MapOrder tLhsOrder, 
            gemmlowp::MapOrder tRhsOrder, gemmlowp::MapOrder tResultOrder>
void PackMatrixMultiplicationMask(
        PackMatrixWithStorage<tScalar, tScalrPack, tLhsOrder> lhs, 
        PackMatrixWithStorage<tScalar, tScalrPack, tRhsOrder> rhs,
        PackMatrixWithStorage<tScalar, tScalrPack, tResultOrder>* result
        ) {
    assert(lhs.GetOpPackRatio() == rhs.GetOpPackRatio());
    assert(lhs.GetActivationBitMask() == rhs.GetActivationBitMask());

    const auto lhs_b1_rep = lhs.ConstMapPack(MapOption::B1REP);
    const auto lhs_b2_rep = lhs.ConstMapPack(MapOption::B2REP);
    const auto rhs_packed = rhs.ConstMapPack();
    const auto row_correction = lhs.GetB2FirstSum();

    auto out = result->Map();

    assert(lhs_b1_rep.cols() == rhs_packed.rows());
    assert(lhs_b1_rep.rows() == out.rows());
    assert(rhs_packed.cols() == out.cols());

    tScalar lane_mask = lhs.GetActivationBitMask();
    tScalrPack word_mask = lhs.GetActivationBitMaskPack();
    int extend_shift = sizeof(tScalar) * 8 - lhs.GetActivationBit();
    int lanes_per_word = lhs.GetOpPackRatio();
    int lane_bits = sizeof(tScalar) * 8;

    const auto out_rows = lhs_b1_rep.rows();
    const auto out_cols = rhs_packed.cols();
    const auto packed_depth = lhs_b1_rep.cols();

    for (int row = 0; row < out_rows; ++row) {
        for (int col = 0; col < out_cols; ++col) {
            // Step 1: masked accumulation of packed words along the depth.
            tScalrPack packed_acc = 0;
            for (int d = 0; d < packed_depth; ++d) {
                tScalrPack term = rhs_packed(d, col) & word_mask;
                term &= lhs_b1_rep(row, d);
                term ^= lhs_b2_rep(row, d);
                packed_acc += term;
                packed_acc &= word_mask;
            }

            // Step 2: peel the lanes off one at a time, low lane first.
            tScalrPack lane_total = 0;
            for (int lane = 0; lane < lanes_per_word; ++lane) {
                tScalar shifted = (static_cast<tScalar>(packed_acc & lane_mask)) << extend_shift;
                lane_total += (shifted >> extend_shift);
                packed_acc = packed_acc >> lane_bits;
            }

            // Step 3: add the per-row term supplied by lhs.
            out(row, col) = static_cast<tScalar>(lane_total) + row_correction[row];
        }
    }
}

// Multiplies lhs by rhs into *result using the engine currently stored in
// mm_engine_option (see SetMatMulEngine):
//   NORMAL   -> NormalMatrixMultiplication on the unpacked views
//   PACK     -> PackMatrixMultiplicationMask on the packed views
//   GEMMLOWP -> not implemented
//
// An unimplemented or out-of-range engine terminates the program. The
// explicit std::abort() matters for builds that define NDEBUG: there the
// assert compiles away, and the function would otherwise return normally
// without ever writing *result.
template <typename tScalar, typename tScalrPack, gemmlowp::MapOrder tLhsOrder, 
            gemmlowp::MapOrder tRhsOrder, gemmlowp::MapOrder tResultOrder>
void MatMulForward(PackMatrixWithStorage<tScalar, tScalrPack, tLhsOrder> lhs, 
        PackMatrixWithStorage<tScalar, tScalrPack, tRhsOrder> rhs,
        PackMatrixWithStorage<tScalar, tScalrPack, tResultOrder>* result
        ) {
    switch (mm_engine_option) {
        case MMOPTION::NORMAL: {
            // Only this engine needs the unpacked output view.
            auto matrix_map = result->Map();
            NormalMatrixMultiplication(lhs.ConstMap(), rhs.ConstMap(), &matrix_map);
            return;
        }
        case MMOPTION::PACK:
            PackMatrixMultiplicationMask(lhs, rhs, result);
            return;
        case MMOPTION::GEMMLOWP:
            // Not implemented
            break;
    }

    // Reached for GEMMLOWP and for any value outside the enumerators.
    std::cerr << "MatMulForward: unsupported matrix multiplication engine\n";
    assert(false && "MatMulForward: unsupported matrix multiplication engine");
    std::abort();
}

// Overload selected when both the scalar and the pack type are float.
// It ignores the configured engine and always runs NormalMatrixMultiplication
// on the unpacked views.
template<gemmlowp::MapOrder tLhsOrder, gemmlowp::MapOrder tRhsOrder, gemmlowp::MapOrder tResultOrder>
void MatMulForward(PackMatrixWithStorage<float, float, tLhsOrder> lhs, 
        PackMatrixWithStorage<float, float, tRhsOrder> rhs,
        PackMatrixWithStorage<float, float, tResultOrder>* result
        ) {
    auto output_view = result->Map();
    const auto lhs_view = lhs.ConstMap();
    const auto rhs_view = rhs.ConstMap();
    NormalMatrixMultiplication(lhs_view, rhs_view, &output_view);
}


#endif
