#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include "../public/gemmlowp.h"
#include "matrix_storage.h"

// Reference (naive triple-loop) matrix product: *result = lhs * rhs.
//
// Shapes must satisfy (R x D) * (D x C) -> (R x C); this is checked with
// assert only. Every output entry is first set to 0 and then accumulated
// in place through *result using tScalar arithmetic, so narrow integer
// element types may overflow/wrap during accumulation.
// Because each entry is zeroed before the inputs are read, *result should
// not share memory with lhs or rhs.
template <typename tScalar, gemmlowp::MapOrder tLhsOrder, gemmlowp::MapOrder tRhsOrder,
                    gemmlowp::MapOrder tResultOrder>
void NormalMatrixMultiplication(
        const gemmlowp::MatrixMap<const tScalar, tLhsOrder>& lhs,
        const gemmlowp::MatrixMap<const tScalar, tRhsOrder>& rhs,
        gemmlowp::MatrixMap<tScalar, tResultOrder>* result) {
    assert(lhs.cols() == rhs.rows());
    assert(lhs.rows() == result->rows());
    assert(rhs.cols() == result->cols());

    const int out_rows = lhs.rows();
    const int out_cols = rhs.cols();
    const int depth = lhs.cols();

    for (int row = 0; row < out_rows; ++row) {
        for (int col = 0; col < out_cols; ++col) {
            // Accumulate directly into the destination entry.
            auto& cell = (*result)(row, col);
            cell = 0;
            for (int d = 0; d < depth; ++d) {
                cell += lhs(row, d) * rhs(d, col);
            }
        }
    }
}

// Packed matrix product where the inner accumulation runs on whole
// tScalrPack words, each word holding lhs.GetOpPackRatio() lanes that are
// sizeof(tScalar) * 8 bits wide.
//
// For each output entry (i, k), every position j of the shared dimension
// contributes the word
//     ((rhs(j, k) & lhs_B1REP(i, j)) ^ lhs_B2REP(i, j)) + lhs_B2FIRST(i, j)
// to a packed accumulator. The accumulator is then reduced horizontally:
// it is added to a running total once per lane, shifting it right by one
// lane width between additions. Only the low tScalar bits of that total
// are stored into result(i, k).
//
// No masking is applied between word additions, so a carry out of one lane
// flows into the next. NOTE(review): this presumably relies on the packed
// operands being small enough that lanes never overflow — confirm.
// NOTE(review): lhs/rhs are taken by value; if PackMatrixWithStorage owns
// its buffer, each call copies it.
template <typename tScalar, typename tScalrPack, gemmlowp::MapOrder tLhsOrder, 
            gemmlowp::MapOrder tRhsOrder, gemmlowp::MapOrder tResultOrder>
void PackMatrixMultiplicationFull(
        PackMatrixWithStorage<tScalar, tScalrPack, tLhsOrder> lhs, 
        PackMatrixWithStorage<tScalar, tScalrPack, tRhsOrder> rhs,
        PackMatrixWithStorage<tScalar, tScalrPack, tResultOrder>* result
        ) {
    assert(lhs.GetOpPackRatio() == rhs.GetOpPackRatio());

    // Three packed views of lhs, plus the packed rhs.
    const auto b1_rep = lhs.ConstMapPack(MapOption::B1REP);
    const auto b2_rep = lhs.ConstMapPack(MapOption::B2REP);
    const auto b2_first = lhs.ConstMapPack(MapOption::B2FIRST);
    const auto rhs_pack = rhs.ConstMapPack();

    auto out = result->Map();

    assert(b1_rep.cols() == rhs_pack.rows());
    assert(b1_rep.rows() == out.rows());
    assert(rhs_pack.cols() == out.cols());

    const int lanes = lhs.GetOpPackRatio();
    const int lane_bits = sizeof(tScalar) * 8;
    const int depth = b1_rep.cols();

    for (int row = 0; row < b1_rep.rows(); ++row) {
        for (int col = 0; col < rhs_pack.cols(); ++col) {
            // Word-parallel accumulation along the shared dimension.
            tScalrPack packed_acc = 0;
            for (int d = 0; d < depth; ++d) {
                tScalrPack term = rhs_pack(d, col);
                term &= b1_rep(row, d);
                term ^= b2_rep(row, d);
                term += b2_first(row, d);
                packed_acc += term;
            }

            // Horizontal reduction: bring every lane down to the low bits.
            tScalrPack total = 0;
            for (int lane = 0; lane < lanes; ++lane) {
                total += packed_acc;
                packed_acc >>= lane_bits;
            }
            out(row, col) = static_cast<tScalar>(total);
        }
    }
}

// Packed matrix product for activations of lhs.GetActivationBit() bits held
// in tScalar-wide lanes of tScalrPack words.
//
// For each output entry (i, k), every position j of the shared dimension
// contributes the word
//     (rhs(j, k) & lhs_B1REP(i, j)) ^ lhs_B2REP(i, j)
// to a packed accumulator. After every addition the accumulator is ANDed with
// lhs.GetActivationBitMaskPack(). NOTE(review): this presumably keeps only
// the low activation bits of each lane so carries cannot spill into the next
// lane — confirm against the mask construction.
//
// Lane reduction: each lane is masked with lhs.GetActivationBitMask(), cast
// to tScalar, shifted left and then right by (lane width - activation bits),
// and added into a running total. When tScalar is signed, this sign-extends
// each lane from its activation width. NOTE(review): if tScalar is unsigned,
// the right shift is logical and no sign extension happens — confirm this
// is intended.
//
// Unlike PackMatrixMultiplicationFull, there is no per-word B2FIRST term.
// Instead, lhs.GetB2FirstSum()[i] is added once to every output entry of row i.
template <typename tScalar, typename tScalrPack, gemmlowp::MapOrder tLhsOrder, 
            gemmlowp::MapOrder tRhsOrder, gemmlowp::MapOrder tResultOrder>
void PackMatrixMultiplicationMask(
        PackMatrixWithStorage<tScalar, tScalrPack, tLhsOrder> lhs, 
        PackMatrixWithStorage<tScalar, tScalrPack, tRhsOrder> rhs,
        PackMatrixWithStorage<tScalar, tScalrPack, tResultOrder>* result
        ) {
    assert(lhs.GetOpPackRatio() == rhs.GetOpPackRatio());
    assert(lhs.GetActivationBitMask() == rhs.GetActivationBitMask());

    const auto lhs_b1_rep_map_pack = lhs.ConstMapPack(MapOption::B1REP);
    const auto lhs_b2_rep_map_pack = lhs.ConstMapPack(MapOption::B2REP);
    const auto rhs_map_pack = rhs.ConstMapPack();
    // Per-row stand-in for the B2FIRST term; bound by reference to avoid
    // copying it if the getter returns a reference.
    const auto& lhs_b2_first_sum = lhs.GetB2FirstSum();

    auto result_map = result->Map();

    assert(lhs_b1_rep_map_pack.cols() == rhs_map_pack.rows());
    assert(lhs_b1_rep_map_pack.rows() == result_map.rows());
    assert(rhs_map_pack.cols() == result_map.cols());

    const tScalar mask = lhs.GetActivationBitMask();
    const tScalrPack mask_pack = lhs.GetActivationBitMaskPack();
    const int neg_shift = sizeof(tScalar) * 8 - lhs.GetActivationBit();
    const int pack_ratio = lhs.GetOpPackRatio();
    const int pack_shift = sizeof(tScalar) * 8;
    // A negative shift count is undefined behaviour: the activation width
    // must not exceed the tScalar lane width.
    assert(neg_shift >= 0);

    for (int i = 0; i < lhs_b1_rep_map_pack.rows(); i++) {
        for (int k = 0; k < rhs_map_pack.cols(); k++) {
            tScalrPack cur_sum = 0, scalar_sum = 0;
            for (int j = 0; j < lhs_b1_rep_map_pack.cols(); j++) {
                tScalrPack sum = rhs_map_pack(j, k);
                sum &= lhs_b1_rep_map_pack(i, j);
                sum ^= lhs_b2_rep_map_pack(i, j);
                cur_sum += sum;
                // Drop bits outside the per-lane activation masks.
                cur_sum &= mask_pack;
            }
            for (int l = 0; l < pack_ratio; l++) {
                // Move the lane's top activation bit to the tScalar sign bit,
                // then shift back so a signed tScalar sign-extends the lane.
                tScalar tmp = (static_cast<tScalar>(cur_sum & mask)) << neg_shift;
                scalar_sum += (tmp >> neg_shift);
                cur_sum = cur_sum >> pack_shift;
            }
            result_map(i, k) = static_cast<tScalar>(scalar_sum) + lhs_b2_first_sum[i];
        }
    }
}
