#ifndef _LAYER_SP_HPP_
#define _LAYER_SP_HPP_

#include "dnnl.hpp"
#include "immintrin.h"
#include <vector>
using namespace dnnl;
/// Three AVX vectors (8 floats each) processed together. load_triple and
/// dump_triple fill/store one 8-float slice from each of three planes.
struct Triple_t {
  __m256 first;   ///< lane 0
  __m256 second;  ///< lane 1
  __m256 third;   ///< lane 2
};

/// Three row pointers of a 3x3 coefficient matrix. encrypt/decrypt read each
/// row as row[0], row[1], row[2] (so each must point to at least three __m256).
struct M_t {
  __m256* first;   ///< row 0 coefficients
  __m256* second;  ///< row 1 coefficients
  __m256* third;   ///< row 2 coefficients
};


 /// Loads 8 floats from each of three planes of `ptr`. The planes are
 /// `image_size` floats apart, and reading starts at element `pos` of the first.
 /// _mm256_load_ps requires every one of the three addresses to be 32-byte aligned.
 inline Triple_t load_triple(float* ptr, int pos, int image_size) {
    const float* const plane0 = ptr + pos;
    const float* const plane1 = plane0 + image_size;
    const float* const plane2 = plane1 + image_size;
    return Triple_t{_mm256_load_ps(plane0), _mm256_load_ps(plane1), _mm256_load_ps(plane2)};
}

 /// Stores the three lanes of `input` into three planes of `ptr`. The planes are
 /// `image_size` floats apart, and writing starts at element `pos` of the first.
 /// Uses _mm256_stream_ps (non-temporal store), so each destination must be
 /// 32-byte aligned. NOTE(review): streaming stores are weakly ordered, so an
 /// _mm_sfence may be needed before another thread reads the data.
 inline void dump_triple(Triple_t& input, float* ptr, int pos, int image_size) {
    float* dst = ptr + pos;
    _mm256_stream_ps(dst, input.first);
    dst += image_size;
    _mm256_stream_ps(dst, input.second);
    dst += image_size;
    _mm256_stream_ps(dst, input.third);
}

 inline void dump_triple_none(Triple_t& input, float* ptr, int pos, int image_size) {
}

 /// Applies the first two rows of the 3x3 coefficient matrix `m` to `input`:
 ///   res.first  = m.first[0]*input.first  + m.first[1]*input.second  + m.first[2]*input.third
 ///   res.second = m.second[0]*input.first + m.second[1]*input.second + m.second[2]*input.third
 /// Each sum is evaluated as (p0 + p1) + p2 with separate mul/add intrinsics.
 /// Only m.first and m.second are read; m.third is not touched.
 /// res.third is not produced by this transform and is returned as zero, so the
 /// result is fully initialized (e.g. when passed on to encrypt(), which reads
 /// all three lanes). NOTE(review): confirm no caller expects a meaningful third lane.
 inline Triple_t decrypt(M_t& m, Triple_t& input) {
    Triple_t res;
    __m256 fs = _mm256_mul_ps(m.first[0], input.first);
    __m256 ss = _mm256_mul_ps(m.first[1], input.second);
    __m256 ts = _mm256_mul_ps(m.first[2], input.third);
    res.first = _mm256_add_ps(_mm256_add_ps(fs, ss), ts);

    fs = _mm256_mul_ps(m.second[0], input.first);
    ss = _mm256_mul_ps(m.second[1], input.second);
    ts = _mm256_mul_ps(m.second[2], input.third);
    res.second = _mm256_add_ps(_mm256_add_ps(fs, ss), ts);

    // Never leave a member indeterminate: copying it out is undefined, and
    // a garbage lane (possibly NaN) would poison a later encrypt().
    res.third = _mm256_setzero_ps();
    return res;
}

 /// Normalizes the first two lanes of `input`: ((x - mean_vec) / std_vec) * gamma + beta.
 /// This is a true IEEE division, so elements where std_vec is 0 yield inf/NaN.
 /// The third lane is not normalized. It is returned as zero so the result is
 /// fully initialized. NOTE(review): confirm callers do not expect input.third
 /// to be passed through instead.
 inline Triple_t norm_triple(Triple_t& input, __m256 mean_vec, __m256 std_vec, __m256 gamma, __m256 beta) {
    Triple_t res;
    __m256 centered = _mm256_sub_ps(input.first, mean_vec);
    res.first = _mm256_add_ps(_mm256_mul_ps(_mm256_div_ps(centered, std_vec), gamma), beta);

    centered = _mm256_sub_ps(input.second, mean_vec);
    res.second = _mm256_add_ps(_mm256_mul_ps(_mm256_div_ps(centered, std_vec), gamma), beta);

    // Previously left unwritten (indeterminate); zero it deterministically.
    res.third = _mm256_setzero_ps();
    return res;
}

 /// Multiplies the three lanes of `input` by the 3x3 coefficient matrix `m`:
 ///   out.k = m.k[0]*input.first + m.k[1]*input.second + m.k[2]*input.third
 /// for k in {first, second, third}. Each row is evaluated as (p0 + p1) + p2
 /// with separate mul and add intrinsics.
 inline Triple_t encrypt(M_t& m, Triple_t& input) {
    const auto mix_row = [&input](const __m256* coeff) {
        const __m256 head = _mm256_add_ps(_mm256_mul_ps(coeff[0], input.first),
                                          _mm256_mul_ps(coeff[1], input.second));
        return _mm256_add_ps(head, _mm256_mul_ps(coeff[2], input.third));
    };
    return Triple_t{mix_row(m.first), mix_row(m.second), mix_row(m.third)};
}

 /// Element-wise ReLU, max(x, 0), on the first two lanes of `input`.
 /// _mm256_max_ps returns its second operand when either operand is NaN, so NaN maps to 0.
 /// The third lane is not activated. It is returned as zero so the result is
 /// fully initialized. NOTE(review): confirm callers do not expect input.third
 /// to be passed through.
 inline Triple_t relu(Triple_t& input) {
    const __m256 zero_v = _mm256_setzero_ps();

    Triple_t res;
    res.first  = _mm256_max_ps(input.first, zero_v);
    res.second = _mm256_max_ps(input.second, zero_v);
    res.third  = zero_v;  // previously left indeterminate
    return res;
}

 /// ReLU backward pass on the first two lanes. Each element of `grad` is kept
 /// where the corresponding forward `input` is > 0 and zeroed otherwise. NaN
 /// inputs also zero the gradient, because the ordered compare yields false.
 /// The third lane is returned as zero so the result is fully initialized.
 /// NOTE(review): confirm no caller relies on grad.third here.
 inline Triple_t relu_back(Triple_t& grad, Triple_t& input) {
    const __m256 zero_v = _mm256_setzero_ps();

    // _CMP_LT_OQ (== 0x11): all-ones lanes where 0 < input, ordered & non-signaling.
    const __m256 mask0 = _mm256_cmp_ps(zero_v, input.first, _CMP_LT_OQ);
    const __m256 mask1 = _mm256_cmp_ps(zero_v, input.second, _CMP_LT_OQ);

    Triple_t res;
    res.first  = _mm256_and_ps(mask0, grad.first);
    res.second = _mm256_and_ps(mask1, grad.second);
    res.third  = zero_v;  // previously left indeterminate
    return res;
}

 /// Identity backward pass with the same signature as relu_back. It returns
 /// `grad` unchanged (all three lanes) and ignores the forward input.
 inline Triple_t grad_none(Triple_t& grad, Triple_t& /*input*/) { return grad; }

 /// Identity activation with the same signature as relu. It returns `input`
 /// unchanged (all three lanes).
 inline Triple_t none(Triple_t& input) { return input; }

 /// Identity transform with the same signature as encrypt. It ignores the
 /// matrix and returns `input` unchanged.
 inline Triple_t encrypt_none(M_t& /*m*/, Triple_t& input) { return input; }

/// Abstract layer interface. Concrete layers implement the pure-virtual
/// forward/backward passes and report their sizes and oneDNN memory descriptors.
class LayerSp {
  public:
    LayerSp() = default;
    LayerSp(const LayerSp&) = default;
    LayerSp& operator=(const LayerSp&) = default;
    // Explicitly defaulted so declaring the destructor below does not
    // silently turn moves of derived layers into copies.
    LayerSp(LayerSp&&) = default;
    LayerSp& operator=(LayerSp&&) = default;
    // Virtual so deleting a derived layer through a LayerSp* runs the derived
    // destructor; without it that deletion is undefined behavior.
    virtual ~LayerSp() = default;

    virtual void forward(float*, float*, bool)=0;
    virtual void update_backward(memory::desc)=0;
    virtual void backward(float*, float*)=0;
    virtual int out_size()=0;
    virtual int input_size()=0;
    virtual int type() = 0;
    virtual memory::desc dst_desc()=0;
    virtual memory::desc diff_src_desc()=0;
    // Defined out of line (not visible here). NOTE(review): the parameter names
    // suggest transfers between a cipher and a plain buffer; confirm the units
    // of `size` and the direction against the definitions.
    void dump_src_input(void* cipher, void* plain, int size);
    void load_src_input(void* plain, void* cipher, int size);
    std::vector<void*> iv_ptr;
    std::vector<void*> mac_ptr;
    bool  sharding=true;
    // Null until a layer assigns it, rather than an indeterminate pointer.
    float* saved_src_ = nullptr;
};

#endif