// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// AVX-optimized specializations of the PackingRegisterBlock templates in
// pack.h, for depth-major signed int8 source data.

#ifndef GEMMLOWP_INTERNAL_PACK_AVX_SIGN_DEPTH_MAJOR_H_
#define GEMMLOWP_INTERNAL_PACK_AVX_SIGN_DEPTH_MAJOR_H_

#include <iostream>
#include <immintrin.h>
#include <cstring>
#include "pack.h"
#include "kernel_avx_sign_int8acc.h"


namespace gemmlowp {

// TODO: Add DepthMajorUint8SideMap

// Read-only view over int8 source data laid out depth-major.
using DepthMajorInt8SideMap =
    SideMap<const std::int8_t, SideMapOrder::DepthMajor>;

// Kernel-side format built from `NumCells` depth-major int8 cells, each
// 32 wide and 1 deep.
template <int NumCells>
using DepthMajorSideFormatNCells32x1Int8 =
    KernelSideFormatInt8Inputs<CellFormat<32, 1, CellOrder::DepthMajor>,
                               NumCells>;

// LHS packing: three 32x1 depth-major int8 cells (96 columns wide).
// The source is depth-major, so at a given depth one cell's 32 values are
// contiguous and map onto exactly one unaligned 256-bit load.
template <>
class PackingRegisterBlock<
    DepthMajorInt8SideMap,
    PackedSideBlock<DepthMajorSideFormatNCells32x1Int8<3>>>
    : public PackingRegisterBlockBase<
          DepthMajorInt8SideMap,
          PackedSideBlock<DepthMajorSideFormatNCells32x1Int8<3>>> {
 public:
  using KernelSideFormat = DepthMajorSideFormatNCells32x1Int8<3>;
  using CellFormat = typename KernelSideFormat::Cell;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  // Copies kRegisterSize depth levels of all kCells cells into dst.
  // Output order: for every depth level, cell 0, cell 1, cell 2 back to
  // back, kCellSize bytes each. Then dst is advanced past the written cells.
  // `start_width` is not used by this specialization.
  void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
    // Depth levels handled per step; each step moves 4 x kCells vectors.
    constexpr int kDepthUnroll = 4;
    constexpr int kVectorsPerStep = kDepthUnroll * kCells;

    std::int8_t *packed = reinterpret_cast<std::int8_t *>(dst->current_data());
    const int src_depth_stride = this->complete_src_.depth_stride();

    for (int depth = 0; depth < kRegisterSize; depth += kDepthUnroll) {
      const std::int8_t *src = this->complete_src_.data(0, depth);

      // Load every vector of this step first, then store them all.
      // Vector v holds cell (v % kCells) at depth (depth + v / kCells).
      __m256i vecs[kVectorsPerStep];
      for (int v = 0; v < kVectorsPerStep; ++v) {
        const int local_depth = v / kCells;
        const int cell = v % kCells;
        vecs[v] = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(
            src + local_depth * src_depth_stride + cell * kCellWidth));
      }
      for (int v = 0; v < kVectorsPerStep; ++v) {
        _mm256_storeu_si256(
            reinterpret_cast<__m256i *>(packed + v * kCellSize), vecs[v]);
      }
      packed += kVectorsPerStep * kCellSize;
    }

#ifdef DEBUG_OUTPUT 
    std::cout << "how does current_data look like now at?\n";
    for (int pos = 0; pos < 200; ++pos) {
      const std::int8_t *val_ptr =
          reinterpret_cast<std::int8_t *>(dst->get_data_at_pos(pos));
      std::cout << "pos = " << pos << " val = " << static_cast<int>(*val_ptr)
                << "\n";
    }
#endif

    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

// Pack format for 4x2 rhs format
template <int Cells>
using RhsDepthMajorSideFormatNCells4x1Int8 =
    KernelSideFormatInt8Inputs<CellFormat<4, 1, CellOrder::DepthMajor>, Cells>;

// RHS packing: a single 4x1 depth-major int8 cell.
template <>
class PackingRegisterBlock<
    DepthMajorInt8SideMap,
    PackedSideBlock<RhsDepthMajorSideFormatNCells4x1Int8<1>>>
    : public PackingRegisterBlockBase<
          DepthMajorInt8SideMap,
          PackedSideBlock<RhsDepthMajorSideFormatNCells4x1Int8<1>>> {
 public:
  typedef RhsDepthMajorSideFormatNCells4x1Int8<1> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  // Copies kRegisterSize depth levels of the kKernelWidth-wide source block
  // into dst. Output order: for every depth level, the cells in width order,
  // kCellSize bytes each. Then dst is advanced past the written cells.
  // `start_width` is not used by this specialization.
  void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
    // In a depth-major source, a 1-deep cell at one depth level is
    // kCellWidth contiguous bytes, so each cell is a single flat copy.
    static_assert(kCellDepth == 1 && kCellSize == kCellWidth,
                  "Pack copies one contiguous kCellWidth-byte row per cell");

    std::int8_t *dst_ptr = reinterpret_cast<std::int8_t *>(dst->current_data());

    // Copy exactly the kRegisterSize depth levels that
    // seek_forward_n_cells() below accounts for. This matches the LHS
    // specialization instead of assuming a fixed depth.
    // No masking is needed: kKernelWidth is a whole number of cells, so
    // every cell is copied in full.
    for (int depth = 0; depth < kRegisterSize; ++depth) {
      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
           cell_start_width += kCellWidth) {
        const std::int8_t *src_data =
            this->complete_src_.data(cell_start_width, depth);
        // memcpy rather than dereferencing int32_t* casts of int8 storage:
        // that was strict-aliasing UB and could be misaligned. A
        // constant-size memcpy still compiles to a single 32-bit move.
        std::memcpy(dst_ptr, src_data, kCellSize);
        dst_ptr += kCellSize;
      }
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);

#ifdef DEBUG_OUTPUT 
    std::cout << "how does current_data look like now at?\n";
    for (int pos = 0; pos < 200; pos++) {
        std::int8_t* val_ptr = reinterpret_cast<std::int8_t*>(dst->get_data_at_pos(pos));
        std::cout << "pos = " << pos << " val = " << (int)*val_ptr << "\n";
    }
#endif

  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_PACK_AVX_SIGN_DEPTH_MAJOR_H_
