//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// kernel_avx_sign_int8_acc.h: an Intel AVX2 optimized kernel with int8 accumulators.
// Check in kernel_default.h which one(s) are actually used by default.
// Others are mere experiments; they are still covered by tests
// in case they might be useful some day.
//

#ifndef GEMMLOWP_INTERNAL_KERNEL_AVX_SIGN_INT8_ACC_H_
#define GEMMLOWP_INTERNAL_KERNEL_AVX_SIGN_INT8_ACC_H_

#include "kernel.h"

#include <cstring>
#include <iostream>
#include <cassert>


namespace {

// Debug helper: prints the 32 bytes starting at `arr` (the contents of one
// 256-bit register dumped to memory) to std::cout as signed decimal values,
// preceded by a header line that mentions `name`.
//
// arr:  pointer to at least 32 readable bytes; it is only read.
// name: label shown in the header line.
void PrintContent(const std::int8_t* arr, const std::string& name) {
    std::cout << "The content of " << name << " is\n";
    for (int i = 0; i < 32; i++) {
        // Widen to int so the byte is printed as a number, not as a character.
        std::cout << static_cast<int>(arr[i]) << " ";
    }
    std::cout << "\n";
}
}


namespace gemmlowp {

#ifdef GEMMLOWP_AVX2_64
// Experimental AVX2 GEMM kernel that keeps its accumulators in 8 bits.
//
// For every depth level it consumes a 96x1 block of Lhs (three 32x1 cells) and
// a 1x4 cell of Rhs, and updates a 96x4 block of int8 accumulators held in
// ymm4..ymm15. The per-element "product" is computed with vpsignb, which
// negates, zeroes or keeps each Rhs byte according to the sign of the matching
// Lhs byte; that equals Lhs * Rhs only when every Lhs value is -1, 0 or +1
// (presumably guaranteed by the packing stage -- TODO confirm). Sums are
// accumulated with vpaddb and therefore wrap around modulo 256.
//
// NOTE(review): the struct name mentions "24x8Depth2", but the Format below and
// the string returned by Name() describe 3 cells of 32x1 and 1 cell of 4x1.
struct AVX2_64_Kernel24x8Depth2_Int8Operands_Int8Inputs_Int8Acc : 
    KernelBase {
  typedef KernelFormat<KernelSideFormatInt8Inputs<CellFormat<32, 1, CellOrder::DepthMajor>, 3>,
                       KernelSideFormatInt8Inputs<CellFormat<4, 1, CellOrder::DepthMajor>, 1>>
      Format;

  const char *Name() const override { return "AVX, 32 x 3, 4 x 1, depth 1 integer, int8 acc!!!!"; }

  // Computes one 96x4 destination block.
  //
  // dst_ptr:        destination block; each of the 4 columns receives 96
  //                 contiguous bytes (assumes dstType is 1 byte wide -- TODO
  //                 confirm against kernel.h).
  // dst_row_stride: must be 1 (asserted).
  // dst_col_stride: offset between consecutive destination columns; it is
  //                 used directly as a byte offset in the addressing below.
  // lhs_ptr:        packed Lhs, 96 bytes are read per depth level.
  // rhs_ptr:        packed Rhs, 4 bytes are read per depth level.
  // start_depth:    when zero the block overwrites the destination, otherwise
  //                 it is added to what the destination already holds.
  // run_depth:      number of depth levels to process. The loop body always
  //                 runs at least once, so this is expected to be >= 1.
  void Run(dstType *dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
           const std::uint8_t *lhs_ptr, const std::uint8_t *rhs_ptr, std::size_t start_depth,
           std::size_t run_depth) const override {
    ScopedProfilingLabel label("optimized kernel");
    assert(dst_row_stride == 1);
    std::int64_t run_depth_cells = run_depth / Format::kDepth;
    const std::int64_t dst_col_stride_q = dst_col_stride;

    // vpshufb control mask. vpshufb works independently on each 128-bit lane;
    // with the 4 Rhs bytes present at the bottom of both lanes, this mask
    // expands them so that 64-bit quarter k of the register holds 8 copies of
    // Rhs byte k. The table is constant, so it is built once instead of on
    // every call.
    static const std::int8_t pshuf_reorg_elem_to_hom_block[32] =
        {0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,
        2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3};
    const std::int8_t *pshuf_elem_reorg_op = pshuf_reorg_elem_to_hom_block;

    /* Main loop */

    // A 1x4 cell of Rhs is stored in 8bit in ymm1: quarter k of ymm1 (written
    // ymm1[k] below) holds 8 copies of Rhs value k.
    // A 96 x 1 block of 3 32x1 cells Lhs is stored in 8bit in ymm0, one cell
    // at a time, replaced at every Iter.
    // ymm2 is scratch: one Rhs value replicated to all 32 bytes, then the
    // sign-adjusted product.
    // Accumulators are stored in 8bit in ymm4--ymm15.
    //
    //                   +-------+---------------+-------+
    //              Rhs  |ymm1[0]|ymm1[1]|ymm1[2]|ymm1[3]|
    //                   +-------+-------+-------+-------+
    //
    //                   |       |       |       |       |
    //
    //    Lhs            |       |       |       |       |
    //
    //  +--+--+ - - - -  +-------+-------+-------+-------+
    //  |ymm0 |          | ymm4  | ymm5  | ymm6  | ymm7  |
    //  |ymm0 | (Iter1)  | ymm4  | ymm5  | ymm6  | ymm7  |
    //  |ymm0 |          | ymm4  | ymm5  | ymm6  | ymm7  |
    //  |ymm0 |          | ymm4  | ymm5  | ymm6  | ymm7  |
    //  +--+--+ - - - -  +-------+-------+-------+-------+
    //  |ymm0 |          | ymm8  | ymm9  | ymm10 | ymm11 |
    //  |ymm0 | (Iter2)  | ymm8  | ymm9  | ymm10 | ymm11 |
    //  |ymm0 |          | ymm8  | ymm9  | ymm10 | ymm11 |
    //  |ymm0 |          | ymm8  | ymm9  | ymm10 | ymm11 |
    //  +--+--+ - - - -  +-------+-------+-------+-------+
    //  |ymm0 |          | ymm12 | ymm13 | ymm14 | ymm15 |
    //  |ymm0 | (Iter3)  | ymm12 | ymm13 | ymm14 | ymm15 |
    //  |ymm0 |          | ymm12 | ymm13 | ymm14 | ymm15 |
    //  |ymm0 |          | ymm12 | ymm13 | ymm14 | ymm15 |
    //  +--+--+ - - - -  +-------+-------+-------+-------+
    //                              Accumulator

    asm volatile(

        // Set accumulators to zero.
        "vpxor %%ymm4, %%ymm4, %%ymm4 \n\t"    // zero accumulators
        "vpxor %%ymm5, %%ymm5, %%ymm5 \n\t"    // zero accumulators
        "vpxor %%ymm6, %%ymm6, %%ymm6 \n\t"    // zero accumulators
        "vpxor %%ymm7, %%ymm7, %%ymm7 \n\t"    // zero accumulators
        "vpxor %%ymm8, %%ymm8, %%ymm8 \n\t"    // zero accumulators
        "vpxor %%ymm9, %%ymm9, %%ymm9 \n\t"    // zero accumulators
        "vpxor %%ymm10, %%ymm10, %%ymm10\n\t"  // zero accumulators
        "vpxor %%ymm11, %%ymm11, %%ymm11\n\t"  // zero accumulators
        "vpxor %%ymm12, %%ymm12, %%ymm12\n\t"  // zero accumulators
        "vpxor %%ymm13, %%ymm13, %%ymm13\n\t"  // zero accumulators
        "vpxor %%ymm14, %%ymm14, %%ymm14\n\t"  // zero accumulators
        "vpxor %%ymm15, %%ymm15, %%ymm15\n\t"  // zero accumulators

        "movq  %[run_depth_cells], %%r14 \n\t"  // loop counter in r14

        // The loop processes one depth level (one Lhs block, one Rhs cell)
        // per iteration.
        "outerLoop%=: \n\t"

        // Load exactly the 4 Rhs bytes of this depth level and replicate them
        // into every 32-bit element. Only 4 bytes are consumed per iteration,
        // so a full 32-byte load would read past the end of the Rhs data on
        // the last iterations.
        "vpbroadcastd (%[rhs_ptr]), %%ymm1 \n\t"
        // Expand: quarter k of ymm1 becomes 8 copies of Rhs byte k.
        "vpshufb (%[pshuf_elem_reorg_op]), %%ymm1, %%ymm1 \n\t"

        // Iter 1
        "vmovdqu (%[lhs_ptr]), %%ymm0 \n\t" // move lhs to ymm0

        // lhs[0:32,:] * rhs[0] 
        // Duplicate rhs[0] to all 32 bytes
        "vpermq $0x00, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm4, %%ymm4 \n\t"
        
        // lhs[0:32,:] * rhs[1] 
        // Duplicate rhs[1] to all 32 bytes
        "vpermq $0x55, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm5, %%ymm5 \n\t"

        // lhs[0:32,:] * rhs[2] 
        // Duplicate rhs[2] to all 32 bytes
        "vpermq $0xaa, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm6, %%ymm6 \n\t"
        
        // lhs[0:32,:] * rhs[3] 
        // Duplicate rhs[3] to all 32 bytes
        "vpermq $0xff, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm7, %%ymm7 \n\t"

        // Iter 2
        "vmovdqu 0x20(%[lhs_ptr]), %%ymm0 \n\t" // move lhs to ymm0

        // lhs[32:64,:] * rhs[0] 
        // Duplicate rhs[0] to all 32 bytes
        "vpermq $0x00, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm8, %%ymm8 \n\t"
        
        // lhs[32:64,:] * rhs[1] 
        // Duplicate rhs[1] to all 32 bytes
        "vpermq $0x55, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm9, %%ymm9 \n\t"

        // lhs[32:64,:] * rhs[2] 
        // Duplicate rhs[2] to all 32 bytes
        "vpermq $0xaa, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm10, %%ymm10 \n\t"
        
        // lhs[32:64,:] * rhs[3] 
        // Duplicate rhs[3] to all 32 bytes
        "vpermq $0xff, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm11, %%ymm11 \n\t"

        // Iter 3
        "vmovdqu 0x40(%[lhs_ptr]), %%ymm0 \n\t" // move lhs to ymm0

        // lhs[64:96,:] * rhs[0] 
        // Duplicate rhs[0] to all 32 bytes
        "vpermq $0x00, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm12, %%ymm12 \n\t"
        
        // lhs[64:96,:] * rhs[1] 
        // Duplicate rhs[1] to all 32 bytes
        "vpermq $0x55, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm13, %%ymm13 \n\t"

        // lhs[64:96,:] * rhs[2] 
        // Duplicate rhs[2] to all 32 bytes
        "vpermq $0xaa, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm14, %%ymm14 \n\t"
        
        // lhs[64:96,:] * rhs[3] 
        // Duplicate rhs[3] to all 32 bytes
        "vpermq $0xff, %%ymm1, %%ymm2 \n\t"
        // Elem-wise sign-multiply and add to accum 
        "vpsignb %%ymm0, %%ymm2, %%ymm2 \n\t"
        "vpaddb %%ymm2, %%ymm15, %%ymm15 \n\t"

        // move forward the pointers: 96 bytes of lhs, 4 bytes of rhs
        "addq $0x60, %[lhs_ptr] \n\t"
        "addq $0x04, %[rhs_ptr] \n\t"

        // Loop while the counter stays strictly positive after decrement.
        "subq $1, %%r14 \n\t"
        "ja outerLoop%= \n\t"
        // outerLoop ends here

        // Set registers for destination
        "movq  %[dst_col_stride_q], %%r12\n\t"  // r12 = column stride (bytes)
        "leaq (%%r12,%%r12,0x2), %%r13\n\t"     // r13 = 3 * column stride
        "test %[start_depth], %[start_depth] \n\t"

        // Storing result: with start_depth == 0 the block is simply written,
        // otherwise the previous destination content is added first.
        "jz storeDst%= \n\t"

        // The accumulators are packed 8-bit values, so the previous partial
        // sums must be added byte-wise (vpaddb). A 32-bit add would let
        // carries leak from one int8 result into its neighbour.
        "vpaddb 0x00(%[dst_ptr]), %%ymm4, %%ymm4 \n\t"    // rhs0
        "vpaddb 0x20(%[dst_ptr]), %%ymm8, %%ymm8 \n\t"    // rhs0
        "vpaddb 0x40(%[dst_ptr]), %%ymm12, %%ymm12 \n\t"  // rhs0

        "vpaddb 0x00(%[dst_ptr], %%r12, 1) , %%ymm5, %%ymm5   \n\t"  // rhs1
        "vpaddb 0x20(%[dst_ptr], %%r12, 1) , %%ymm9, %%ymm9   \n\t"  // rhs1
        "vpaddb 0x40(%[dst_ptr], %%r12, 1) , %%ymm13, %%ymm13 \n\t"  // rhs1

        "vpaddb 0x00(%[dst_ptr], %%r12, 2) , %%ymm6, %%ymm6   \n\t"  // rhs2
        "vpaddb 0x20(%[dst_ptr], %%r12, 2) , %%ymm10, %%ymm10 \n\t"  // rhs2
        "vpaddb 0x40(%[dst_ptr], %%r12, 2) , %%ymm14, %%ymm14 \n\t"  // rhs2

        "vpaddb 0x00(%[dst_ptr], %%r13, 1) , %%ymm7, %%ymm7   \n\t"  // rhs3
        "vpaddb 0x20(%[dst_ptr], %%r13, 1) , %%ymm11, %%ymm11 \n\t"  // rhs3
        "vpaddb 0x40(%[dst_ptr], %%r13, 1) , %%ymm15, %%ymm15 \n\t"  // rhs3

        "storeDst%=:\n\t"

        "vmovdqu %%ymm4, 0x00(%[dst_ptr])            \n\t"  // rhs0
        "vmovdqu %%ymm8, 0x20(%[dst_ptr])            \n\t"  // rhs0
        "vmovdqu %%ymm12, 0x40(%[dst_ptr])           \n\t"  // rhs0

        "vmovdqu %%ymm5, 0x00(%[dst_ptr], %%r12, 1)  \n\t"  // rhs1
        "vmovdqu %%ymm9, 0x20(%[dst_ptr], %%r12, 1)  \n\t"  // rhs1
        "vmovdqu %%ymm13, 0x40(%[dst_ptr], %%r12, 1) \n\t"  // rhs1

        "vmovdqu %%ymm6, 0x00(%[dst_ptr], %%r12, 2)  \n\t"  // rhs2
        "vmovdqu %%ymm10, 0x20(%[dst_ptr], %%r12, 2) \n\t"  // rhs2
        "vmovdqu %%ymm14, 0x40(%[dst_ptr], %%r12, 2) \n\t"  // rhs2

        "vmovdqu %%ymm7, 0x00(%[dst_ptr], %%r13, 1)  \n\t"  // rhs3
        "vmovdqu %%ymm11, 0x20(%[dst_ptr], %%r13, 1) \n\t"  // rhs3
        "vmovdqu %%ymm15, 0x40(%[dst_ptr], %%r13, 1) \n\t"  // rhs3

        :  // outputs (lhs_ptr and rhs_ptr are advanced inside the loop)
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_ptr] "+r"(dst_ptr)
        :  // inputs
        [start_depth] "r"(start_depth), [dst_col_stride_q] "r"(dst_col_stride_q),
        [run_depth_cells] "r"(run_depth_cells),
        [pshuf_elem_reorg_op] "r"(pshuf_elem_reorg_op)
        :  // clobbers
        "cc", "memory", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6", "%ymm7",
        "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "%r12",
        "%r13", "%r14");
  }
};


#endif

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_KERNEL_AVX_SIGN_INT8_ACC_H_
