#ifndef SGXDNN_BATCHNORM_H_
#define SGXDNN_BATCHNORM_H_

#include "assert.h"
#include <iostream>
#include <string>
#include "immintrin.h"
#include "layer.hpp"
#include "Enclave.h"
#include "sgx_tcrypto.h"
#include "Crypto.h"

extern float* temp_buffer;
extern float* temp_buffer2;
extern float* temp_buffer3;
extern float* temp_buffer4;
extern float* bm_buf;
extern float* gm_buf;
extern float* igm_buf;
extern float* um_buf;

// Identity forward activation: yields in1 lane-for-lane.
// in2 (the skip operand used by skip_add) is accepted only so this function
// matches the common activation signature, and is ignored.
inline __m256 none(__m256 in1, __m256 in2) {
    (void)in2;
    const __m256 passthrough = in1;
    return passthrough;
}

// ReLU6 forward activation: clamps every lane of in1 into [0, 6].
// in2 is ignored; it exists only to match the shared activation signature.
inline __m256 relu6(__m256 in1, __m256 in2) {
    (void)in2;
    const __m256 lower = _mm256_setzero_ps();
    const __m256 upper = _mm256_set1_ps(6.0f);
    const __m256 floored = _mm256_max_ps(in1, lower);
    return _mm256_min_ps(floored, upper);
}

// ReLU6 backward: passes grad through in lanes where 0 < input < 6 and
// zeroes it elsewhere. Both comparisons are ordered (_CMP_LT_OQ, the
// predicate encoded as 0x11), so NaN lanes in input also yield zero.
inline __m256 relu6_back(__m256 input, __m256 grad) {
    const __m256 lower = _mm256_setzero_ps();
    const __m256 upper = _mm256_set1_ps(6.0f);

    const __m256 above_lower = _mm256_cmp_ps(lower, input, _CMP_LT_OQ);
    const __m256 below_upper = _mm256_cmp_ps(input, upper, _CMP_LT_OQ);
    const __m256 in_range = _mm256_and_ps(above_lower, below_upper);
    return _mm256_and_ps(in_range, grad);
}

// Backward of the identity activation: the upstream gradient passes through
// unchanged. `input` (the saved forward activation) is not needed; it is kept
// only so this function matches relu6_back's (input, grad) signature.
// Previously this returned `input`, which made bwd() treat activations as
// gradients in the "bn" and "bnadd" modes.
inline __m256 none_back(__m256 input, __m256 grad) {
    (void)input;
    return grad;
}

// Residual ("bnadd") activation: lane-wise sum of the normalized value in1
// and the skip-connection value in2.
inline __m256 skip_add(__m256 in1, __m256 in2) {
    const __m256 summed = _mm256_add_ps(in1, in2);
    return summed;
}


// Skip-input handler that performs no unblinding: forwards x untouched.
// y, z and the coefficients m0..m2 are ignored; they exist only so this
// function is interchangeable with skip_unblind.
inline __m256 skip_pass(__m256 x, __m256 y, __m256 z, float m0, float m1, float m2) {
    (void)y;
    (void)z;
    (void)m0;
    (void)m1;
    (void)m2;
    return x;
}

// Lane-wise dot product of one 3-element matrix row (m0, m1, m2) with the
// vectors (x, y, z), evaluated as (m0*x + m1*y) + m2*z.
// Used with rows of the unblinding matrices (e.g. um_buf[0..2]) to recover a
// real value from three blinded copies.
inline __m256 skip_unblind(__m256 x, __m256 y, __m256 z, float m0, float m1, float m2) {
    const __m256 x_term = _mm256_mul_ps(_mm256_set1_ps(m0), x);
    const __m256 y_term = _mm256_mul_ps(_mm256_set1_ps(m1), y);
    const __m256 z_term = _mm256_mul_ps(_mm256_set1_ps(m2), z);
    return _mm256_add_ps(_mm256_add_ps(x_term, y_term), z_term);
}
using namespace tensorflow;

namespace SGXDNN
{
    /**
     * Batch-normalization layer for the enclave, fused with an optional
     * activation (relu6) or residual add ("bnadd").
     *
     * Normalization uses ONE scalar mean and ONE scalar std for the whole
     * layer output (see fwd() and std()), combined with per-channel scale
     * (gamma, scale_) and center (beta, center_). The channel of flat element
     * j is j % depth, so data is presumably channel-last (NHWC) - TODO confirm.
     *
     * Two execution modes, chosen through update_params():
     *  - plaintext (privacy_ == false): every sample of each virtual batch of
     *    mini_batch_size samples is processed directly.
     *  - blinded (privacy_ == true): each virtual batch stores three blinded
     *    copies, output_size_ floats apart. A 2x3 unblinding matrix (6 floats,
     *    row-major: um_buf1 / um_buf / igm_buf) recovers two real samples, and
     *    results are re-blinded with the rows of the global bm_buf before
     *    being written out. This presumably requires mini_batch_size == 3
     *    - TODO confirm.
     *
     * All AVX loads/stores are the aligned variants, so every buffer
     * (including the Tensor storage behind center_/scale_) must be 32-byte
     * aligned, and the channel depth must be a multiple of 8.
     */
    template <typename T> class BatchNormSp : public Layer<T>
    {
    public:
        explicit BatchNormSp(const std::string& name,
                             array4d input_shape):
        Layer<T>(name, input_shape)
        {
            const int input_rows = input_shape[1];
            const int input_cols = input_shape[2];
            const int input_depth = input_shape[3];

            output_shape_ = {2, input_rows, input_cols, input_depth};
            output_size_ = input_rows * input_cols * input_depth;
            moving_mean_ = 0.0;
            moving_std_  = 0.0;

            // The Tensors own the parameter storage; the raw pointers are
            // non-owning views used by the AVX loops.
            center_t_ = Tensor<float, 1>(input_depth);
            scale_t_ = Tensor<float, 1>(input_depth);
            center_grad_t_ = Tensor<float, 1>(input_depth);
            scale_grad_t_ = Tensor<float, 1>(input_depth);
            center_ = center_t_.data();
            scale_ = scale_t_.data();
            center_grad_ = center_grad_t_.data();
            scale_grad_ = scale_grad_t_.data();

            // Identity affine transform; zeroed gradient accumulators.
            for (int i = 0; i < input_depth; i++) {
                center_[i] = 0.0;
                scale_[i]  = 1.0;
                center_grad_[i] = 0.0;
                scale_grad_[i] = 0.0;
            }

            // The vector loops process 8 channels per step.
            if (input_depth % 8 != 0) {
                printf("batch norm len error\n");
                assert(false);
            }

            // Layers named "bnadd" unblind their skip input; all others pass
            // it through unchanged.
            skip_func = (name == "bnadd") ? skip_unblind : skip_pass;
        }

        // The parameter buffers belong to the Tensor members, which release
        // them on destruction. Calling delete[] on center_/scale_ here would
        // free Tensor-managed memory with the wrong deallocator and then
        // free it a second time.
        ~BatchNormSp() = default;

        // center_/scale_/*_grad_ point into this object's own Tensors, so a
        // memberwise copy would keep aliasing the source object's buffers.
        BatchNormSp(const BatchNormSp&) = delete;
        BatchNormSp& operator=(const BatchNormSp&) = delete;

        /// Lane-wise m0*x + m1*y + m2*0: applies one 3-element blinding row to
        /// the two real vectors x and y (the third operand is a zero vector).
        inline __m256 blind_scale(__m256 x, __m256 y, float m0, float m1, float m2) {
            const __m256 m00 = _mm256_set1_ps(m0);
            const __m256 m01 = _mm256_set1_ps(m1);
            const __m256 m02 = _mm256_set1_ps(m2);

            __m256 z = _mm256_set1_ps((float) 0.0);
            __m256 fs = _mm256_mul_ps(m00, x);
            __m256 ss = _mm256_mul_ps(m01, y);
            __m256 ts = _mm256_mul_ps(m02, z);
            __m256 res0 = _mm256_add_ps(fs, ss);
            res0 = _mm256_add_ps(res0, ts);
            return res0;
        }

        /// Sets the execution mode (blinded or plaintext), the epsilon added to
        /// the variance in std(), and the momentum (stored, not read in this
        /// class).
        void update_params(bool privacy, float eps, float momentum) {
            privacy_ = privacy;
            eps_ = eps;
            momentum_ = momentum;
        }

        array4d output_shape() override
        {
            return output_shape_;
        }

        int output_size() override
        {
            return output_size_;
        }

        /**
         * Layer-wide standard deviation around `mean`:
         *   sqrt(sum((x - mean)^2) / (real_samples * output_size_) + eps_)
         *
         * @param inp             input activations (blinded when privacy_).
         * @param batch_size      total number of stored samples.
         * @param mini_batch_size samples stored per virtual batch.
         * @param mean            scalar mean computed by fwd().
         * @param um_buf1         2x3 unblinding matrix (6 floats); read only
         *                        when privacy_.
         * @return the regularized standard deviation.
         */
        float std(float* inp, int batch_size, int mini_batch_size, float mean, float* um_buf1) {
            const int skip_size = this->output_size_ * mini_batch_size;
            const int mini_batch_len = privacy_ ? mini_batch_size - 1 : mini_batch_size;
            const int num_ele_mini_batch = this->output_size_ * mini_batch_len;

            if (num_ele_mini_batch % 8 != 0) {
                printf("batch norm num elements in a minibatch is wrong %d\n", num_ele_mini_batch);
                assert(false);
            }

            // number of virtual batches
            const int iter = batch_size / mini_batch_size;

            const __m256 mean_vec = _mm256_set1_ps(mean);
            __m256 std_res = _mm256_set1_ps(0.0f);

            const float* inp_ptr = inp;
            if (privacy_) {
                // 2x3 unblinding matrix, row-major
                const __m256 m00 = _mm256_set1_ps(um_buf1[0]);
                const __m256 m01 = _mm256_set1_ps(um_buf1[1]);
                const __m256 m02 = _mm256_set1_ps(um_buf1[2]);
                const __m256 m10 = _mm256_set1_ps(um_buf1[3]);
                const __m256 m11 = _mm256_set1_ps(um_buf1[4]);
                const __m256 m12 = _mm256_set1_ps(um_buf1[5]);

                for (int v_iter = 0; v_iter < iter; v_iter++) {
                    const float* first_ptr  = inp_ptr;
                    const float* second_ptr = inp_ptr + this->output_size_;
                    const float* third_ptr  = inp_ptr + this->output_size_ * 2;
                    for (int j = 0; j < this->output_size_; j += 8) {
                        const __m256 f = _mm256_load_ps(&first_ptr[j]);
                        const __m256 s = _mm256_load_ps(&second_ptr[j]);
                        const __m256 t = _mm256_load_ps(&third_ptr[j]);

                        // accumulate (x - mean)^2 for both recovered samples
                        const __m256 d0 = _mm256_sub_ps(matrix_mul_scale(f, s, t, m00, m01, m02), mean_vec);
                        std_res = _mm256_add_ps(_mm256_mul_ps(d0, d0), std_res);
                        const __m256 d1 = _mm256_sub_ps(matrix_mul_scale(f, s, t, m10, m11, m12), mean_vec);
                        std_res = _mm256_add_ps(_mm256_mul_ps(d1, d1), std_res);
                    }
                    inp_ptr += skip_size;
                }
            } else {
                for (int v_iter = 0; v_iter < iter; v_iter++) {
                    for (int j = 0; j < num_ele_mini_batch; j += 8) {
                        const __m256 d = _mm256_sub_ps(_mm256_load_ps(&inp_ptr[j]), mean_vec);
                        std_res = _mm256_add_ps(_mm256_mul_ps(d, d), std_res);
                    }
                    inp_ptr += skip_size;
                }
            }

            // Horizontal sum of the 8 lanes. A regular store (not a streaming
            // one) because the buffer is read back immediately.
            float std_res_vec[16];
            float* std_res_vec_aligned = ALIGN32(std_res_vec);
            _mm256_store_ps(std_res_vec_aligned, std_res);

            float res = 0.0;
            for (int i = 0; i < 8; i++)
                res += std_res_vec_aligned[i];

            const int effective_total_batch = batch_size * mini_batch_len / mini_batch_size;
            return sqrt(res / (effective_total_batch * this->output_size_) + this->eps_);
        }

        /// Lane-wise (m00*f + m01*s) + m02*t: one row of a 3-column matrix
        /// applied to three blinded copies.
        __m256 matrix_mul_scale(__m256 f, __m256 s, __m256 t,
                                const __m256 m00,
                                const __m256 m01,
                                const __m256 m02
                                ) {
            __m256 fs = _mm256_mul_ps(m00, f);
            __m256 ss = _mm256_mul_ps(m01, s);
            __m256 ts = _mm256_mul_ps(m02, t);
            __m256 res0 = _mm256_add_ps(fs, ss);
            res0 = _mm256_add_ps(res0, ts);

            return res0;
        }

        /**
         * Normalizes, applies the affine transform, and applies act_func:
         *   out = act_func((x - mean) / std * scale_ + center_, skip)
         *
         * When privacy_, inputs are unblinded with um_buf1, and the skip input
         * goes through skip_func with the global um_buf. Both the
         * pre-activation values (into act_src, always) and the final outputs
         * are re-blinded with bm_buf. In plaintext mode act_src receives the
         * pre-activation values only when save_inter is true.
         */
        void norm_func(float* inp, float* skip_input, float* out, float* act_src, float mean,
                       float std, int batch_size,
                       int mini_batch_size, __m256 (* act_func)(__m256, __m256), float* um_buf1, bool save_inter=false) {
            const int skip_size = this->output_size_ * mini_batch_size;
            const int mini_batch_len = privacy_ ? mini_batch_size - 1 : mini_batch_size;
            const int num_ele_mini_batch = this->output_size_ * mini_batch_len;

            if (num_ele_mini_batch % 8 != 0) {
                printf("batch norm num elements in a minibatch is wrong %d\n", num_ele_mini_batch);
                assert(false);
            }

            const __m256 mean_vec = _mm256_set1_ps(mean);
            const __m256 std_vec  = _mm256_set1_ps(static_cast<float>(1.0 / std));
            const int iter = batch_size / mini_batch_size;
            const int depth = static_cast<int>(this->output_shape_[3]);

            float* inp_ptr = inp;
            float* out_ptr = out;
            float* skip_ptr = skip_input;
            float* act_src_ptr = act_src;

            if (privacy_) {
                // 2x3 unblinding matrix, row-major
                const __m256 m00 = _mm256_set1_ps(um_buf1[0]);
                const __m256 m01 = _mm256_set1_ps(um_buf1[1]);
                const __m256 m02 = _mm256_set1_ps(um_buf1[2]);
                const __m256 m10 = _mm256_set1_ps(um_buf1[3]);
                const __m256 m11 = _mm256_set1_ps(um_buf1[4]);
                const __m256 m12 = _mm256_set1_ps(um_buf1[5]);

                for (int v_iter = 0; v_iter < iter; v_iter++) {
                    // the three blinded copies of each buffer
                    const float* first_ptr  = inp_ptr;
                    const float* second_ptr = inp_ptr + this->output_size_;
                    const float* third_ptr  = inp_ptr + this->output_size_ * 2;

                    const float* first_skip_ptr  = skip_ptr;
                    const float* second_skip_ptr = skip_ptr + this->output_size_;
                    const float* third_skip_ptr  = skip_ptr + this->output_size_ * 2;

                    float* first_act_ptr  = act_src_ptr;
                    float* second_act_ptr = act_src_ptr + this->output_size_;
                    float* third_act_ptr  = act_src_ptr + this->output_size_ * 2;

                    float* first_out_ptr  = out_ptr;
                    float* second_out_ptr = out_ptr + this->output_size_;
                    float* third_out_ptr  = out_ptr + this->output_size_ * 2;

                    for (int j = 0; j < this->output_size_; j += 8) {
                        const int channel_ptr = j % depth;
                        const __m256 beta  = _mm256_load_ps(&center_[channel_ptr]);
                        const __m256 gamma = _mm256_load_ps(&scale_[channel_ptr]);

                        const __m256 f = _mm256_load_ps(&first_ptr[j]);
                        const __m256 s = _mm256_load_ps(&second_ptr[j]);
                        const __m256 t = _mm256_load_ps(&third_ptr[j]);

                        const __m256 sk_f = _mm256_load_ps(&first_skip_ptr[j]);
                        const __m256 sk_s = _mm256_load_ps(&second_skip_ptr[j]);
                        const __m256 sk_t = _mm256_load_ps(&third_skip_ptr[j]);

                        // recover the two real samples
                        const __m256 res0 = matrix_mul_scale(f, s, t, m00, m01, m02);
                        const __m256 res1 = matrix_mul_scale(f, s, t, m10, m11, m12);

                        // (x - mean) / std * gamma + beta
                        __m256 temp0 = _mm256_sub_ps(res0, mean_vec);
                        temp0 = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(temp0, std_vec), gamma), beta);
                        __m256 temp1 = _mm256_sub_ps(res1, mean_vec);
                        temp1 = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(temp1, std_vec), gamma), beta);

                        // blind and save the pre-activation values for bwd()
                        __m256 blind0 = blind_scale(temp0, temp1, bm_buf[0], bm_buf[1], bm_buf[2]);
                        __m256 blind1 = blind_scale(temp0, temp1, bm_buf[3], bm_buf[4], bm_buf[5]);
                        __m256 blind2 = blind_scale(temp0, temp1, bm_buf[6], bm_buf[7], bm_buf[8]);
                        _mm256_stream_ps(&first_act_ptr[j], blind0);
                        _mm256_stream_ps(&second_act_ptr[j], blind1);
                        _mm256_stream_ps(&third_act_ptr[j], blind2);

                        // skip_func was chosen in the constructor. Row 1 passes the
                        // first two copies AND their coefficients swapped: skip_pass
                        // then forwards the second copy, while skip_unblind still
                        // computes um_buf[3]*f + um_buf[4]*s + um_buf[5]*t.
                        const __m256 skip0 = skip_func(sk_f, sk_s, sk_t, um_buf[0], um_buf[1], um_buf[2]);
                        const __m256 skip1 = skip_func(sk_s, sk_f, sk_t, um_buf[4], um_buf[3], um_buf[5]);

                        temp0 = act_func(temp0, skip0);
                        temp1 = act_func(temp1, skip1);

                        // blind and write the outputs
                        blind0 = blind_scale(temp0, temp1, bm_buf[0], bm_buf[1], bm_buf[2]);
                        blind1 = blind_scale(temp0, temp1, bm_buf[3], bm_buf[4], bm_buf[5]);
                        blind2 = blind_scale(temp0, temp1, bm_buf[6], bm_buf[7], bm_buf[8]);
                        _mm256_stream_ps(&first_out_ptr[j], blind0);
                        _mm256_stream_ps(&second_out_ptr[j], blind1);
                        _mm256_stream_ps(&third_out_ptr[j], blind2);
                    }
                    inp_ptr += skip_size;
                    out_ptr += skip_size;
                    skip_ptr += skip_size;
                    act_src_ptr += skip_size;
                }
            } else {
                for (int v_iter = 0; v_iter < iter; v_iter++) {
                    for (int j = 0; j < num_ele_mini_batch; j += 8) {
                        const int channel_ptr = j % depth;
                        const __m256 beta  = _mm256_load_ps(&center_[channel_ptr]);
                        const __m256 gamma = _mm256_load_ps(&scale_[channel_ptr]);
                        const __m256 num = _mm256_load_ps(&inp_ptr[j]);
                        const __m256 skip = _mm256_load_ps(&skip_ptr[j]);

                        __m256 temp = _mm256_sub_ps(num, mean_vec);
                        temp = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(temp, std_vec), gamma), beta);
                        if (save_inter)
                            _mm256_stream_ps(&act_src_ptr[j], temp);
                        temp = act_func(temp, skip);
                        _mm256_stream_ps(&out_ptr[j], temp);
                    }
                    inp_ptr += skip_size;
                    out_ptr += skip_size;
                    skip_ptr += skip_size;
                    act_src_ptr += skip_size;
                }
            }
        }

        /// Currently a no-op.
        void fwd_encl(float* output, float* input, bool is_train) {
        }

        /**
         * Forward pass: aggregates the per-sample means into one batch mean,
         * computes the batch std, remembers both (and the mode) for bwd(), then
         * normalizes `input` into `output`.
         *
         * @param means    per-sample means. In blinded mode each virtual batch
         *                 contributes three blinded means, unblinded with um_buf1.
         * @param act_mode "bnrelu" (relu6), "bn" (identity), or "bnadd"
         *                 (residual add, saves pre-activations into act_relu).
         */
        void fwd(float* output, float* input, float* means, float* skip_input, float* act_relu,
                 float* um_buf1, int batch_size, int mini_batch_size, const char* act_mode) {
            float batch_mean = 0.0f;
            batch_size_ = batch_size;
            mini_batch_size_ = mini_batch_size;

            if (privacy_) {
                // Assumes means[] holds one blinded mean per stored sample
                // (batch_size entries, i.e. three per virtual batch), as in the
                // plaintext branch below - TODO confirm against the caller.
                for (int i = 0; i < batch_size; i += mini_batch_size) {
                    const float m0 = means[i];
                    const float m1 = means[i + 1];
                    const float m2 = means[i + 2];

                    // unblind the two real sample means (kept as float: they
                    // are fractional and must not be truncated)
                    const float tm0 = m0 * um_buf1[0] + m1 * um_buf1[1] + m2 * um_buf1[2];
                    const float tm1 = m0 * um_buf1[3] + m1 * um_buf1[4] + m2 * um_buf1[5];
                    batch_mean += tm0 + tm1;
                }
                // average over the real (unblinded) samples only
                batch_mean /= (batch_size / mini_batch_size * (mini_batch_size - 1));
            } else {
                for (int i = 0; i < batch_size; i++)
                    batch_mean += means[i];
                batch_mean /= batch_size;
            }

            saved_mean = batch_mean;
            const float batch_std = this->std(input, batch_size, mini_batch_size, batch_mean, um_buf1);
            saved_std = batch_std;

            mode_ = std::string(act_mode);

            // Fused activation; defaults to identity so the pointer is never
            // left uninitialized for an unrecognized mode.
            __m256 (* act_func)(__m256, __m256) = none;
            bool save_inter = false;
            if (mode_ == "bnrelu") {
                act_func = relu6;
            } else if (mode_ == "bn") {
                act_func = none;
            } else if (mode_ == "bnadd") {
                act_func = skip_add;
                save_inter = true;
            } else {
                printf("batch norm unknown activation mode %s\n", act_mode);
                assert(false);
            }

            norm_func(input, skip_input, output, act_relu,
                      batch_mean, batch_std, batch_size,
                      mini_batch_size, act_func, um_buf1, save_inter);
        }

        /**
         * Backward pass using the statistics saved by the last fwd():
         * accumulates d/d(center_) into center_grad_ and d/d(scale_) into
         * scale_grad_ (never reset here), and writes
         * grad_out = act_back(grad) * scale_ / std. The mean/std are treated
         * as constants (no variance/mean correction terms).
         *
         * In blinded mode, grad is unblinded with igm_buf, act_src with um_buf,
         * and grad_out is re-blinded with bm_buf. skip_src is unused.
         * In plaintext mode act_src is only written by fwd() in "bnadd" mode;
         * for "bnrelu" the caller presumably supplies a buffer whose relu6 mask
         * matches the forward pass - TODO confirm.
         */
        void bwd (float* grad_out, float* grad, float* inp, float* skip_src, float* act_src) {
            const int skip_size = this->output_size_ * mini_batch_size_;
            const int mini_batch_len = privacy_ ? mini_batch_size_ - 1 : mini_batch_size_;
            const int num_ele_mini_batch = this->output_size_ * mini_batch_len;
            const int iter = batch_size_ / mini_batch_size_;
            const int depth = static_cast<int>(this->output_shape_[3]);

            float* inp_ptr = inp;
            float* grad_ptr = grad;
            float* act_ptr = act_src;
            float* grad_out_ptr = grad_out;

            __m256 (*act_back) (__m256, __m256) = (mode_ == "bnrelu") ? relu6_back : none_back;

            const __m256 mean_m = _mm256_set1_ps(saved_mean);
            const __m256 std_m  = _mm256_set1_ps(1/saved_std);

            if (privacy_) {
                for (int i = 0; i < iter; i++) {
                    const float* first_act_ptr  = act_ptr;
                    const float* second_act_ptr = act_ptr + this->output_size_;
                    const float* third_act_ptr  = act_ptr + this->output_size_ * 2;

                    const float* first_grad_ptr  = grad_ptr;
                    const float* second_grad_ptr = grad_ptr + this->output_size_;
                    const float* third_grad_ptr  = grad_ptr + this->output_size_ * 2;

                    float* first_out_ptr  = grad_out_ptr;
                    float* second_out_ptr = grad_out_ptr + this->output_size_;
                    float* third_out_ptr  = grad_out_ptr + this->output_size_ * 2;

                    for (int j = 0; j < this->output_size_; j += 8) {
                        const __m256 grad_m0 = _mm256_load_ps(&first_grad_ptr[j]);
                        const __m256 grad_m1 = _mm256_load_ps(&second_grad_ptr[j]);
                        const __m256 grad_m2 = _mm256_load_ps(&third_grad_ptr[j]);

                        const __m256 act_m0 = _mm256_load_ps(&first_act_ptr[j]);
                        const __m256 act_m1 = _mm256_load_ps(&second_act_ptr[j]);
                        const __m256 act_m2 = _mm256_load_ps(&third_act_ptr[j]);

                        // unblind: row 0 uses matrix entries 0..2, row 1 uses 3..5
                        const __m256 grad0 = skip_unblind(grad_m0, grad_m1, grad_m2, igm_buf[0], igm_buf[1], igm_buf[2]);
                        const __m256 grad1 = skip_unblind(grad_m0, grad_m1, grad_m2, igm_buf[3], igm_buf[4], igm_buf[5]);
                        const __m256 act0 = skip_unblind(act_m0, act_m1, act_m2, um_buf[0], um_buf[1], um_buf[2]);
                        const __m256 act1 = skip_unblind(act_m0, act_m1, act_m2, um_buf[3], um_buf[4], um_buf[5]);

                        const __m256 grad_norm0 = act_back(act0, grad0);
                        const __m256 grad_norm1 = act_back(act1, grad1);

                        const int channel_ptr = j % depth;
                        const __m256 beta  = _mm256_load_ps(&center_grad_[channel_ptr]);
                        const __m256 gamma = _mm256_load_ps(&scale_grad_[channel_ptr]);
                        const __m256 scale_m = _mm256_load_ps(&scale_[channel_ptr]);
                        const __m256 center_m = _mm256_load_ps(&center_[channel_ptr]);

                        // act holds gamma * x_hat + beta; recover x_hat
                        const __m256 input0 = _mm256_div_ps(_mm256_sub_ps(act0, center_m), scale_m);
                        const __m256 input1 = _mm256_div_ps(_mm256_sub_ps(act1, center_m), scale_m);

                        __m256 grad_gamma = _mm256_add_ps(_mm256_mul_ps(grad_norm0, input0), gamma);
                        __m256 grad_beta  = _mm256_add_ps(grad_norm0, beta);
                        grad_gamma = _mm256_add_ps(_mm256_mul_ps(grad_norm1, input1), grad_gamma);
                        grad_beta  = _mm256_add_ps(grad_norm1, grad_beta);

                        const __m256 grad_out_m0 = _mm256_mul_ps(grad_norm0, _mm256_mul_ps(scale_m, std_m));
                        const __m256 grad_out_m1 = _mm256_mul_ps(grad_norm1, _mm256_mul_ps(scale_m, std_m));

                        const __m256 blind0 = blind_scale(grad_out_m0, grad_out_m1, bm_buf[0], bm_buf[1], bm_buf[2]);
                        const __m256 blind1 = blind_scale(grad_out_m0, grad_out_m1, bm_buf[3], bm_buf[4], bm_buf[5]);
                        const __m256 blind2 = blind_scale(grad_out_m0, grad_out_m1, bm_buf[6], bm_buf[7], bm_buf[8]);

                        // The accumulators are re-read on later iterations, so use
                        // regular stores; only the write-once outputs are streamed.
                        _mm256_store_ps(&center_grad_[channel_ptr], grad_beta);
                        _mm256_store_ps(&scale_grad_[channel_ptr], grad_gamma);
                        _mm256_stream_ps(&first_out_ptr[j], blind0);
                        _mm256_stream_ps(&second_out_ptr[j], blind1);
                        _mm256_stream_ps(&third_out_ptr[j], blind2);
                    }
                    inp_ptr += skip_size;
                    grad_ptr += skip_size;
                    act_ptr += skip_size;
                    grad_out_ptr += skip_size;
                }
                return;
            }

            // plaintext path
            for (int i = 0; i < iter; i++) {
                for (int j = 0; j < num_ele_mini_batch; j += 8) {
                    const __m256 act_m = _mm256_load_ps(&act_ptr[j]);
                    const __m256 grad_m = _mm256_load_ps(&grad_ptr[j]);
                    __m256 input_m = _mm256_load_ps(&inp_ptr[j]);

                    const int channel_ptr = j % depth;
                    const __m256 beta  = _mm256_load_ps(&center_grad_[channel_ptr]);
                    const __m256 gamma = _mm256_load_ps(&scale_grad_[channel_ptr]);
                    const __m256 scale_m = _mm256_load_ps(&scale_[channel_ptr]);

                    const __m256 grad_norm = act_back(act_m, grad_m);

                    // x_hat = (x - mean) / std
                    input_m = _mm256_mul_ps(_mm256_sub_ps(input_m, mean_m), std_m);
                    const __m256 grad_gamma = _mm256_add_ps(_mm256_mul_ps(grad_norm, input_m), gamma);
                    const __m256 grad_beta  = _mm256_add_ps(grad_norm, beta);

                    const __m256 grad_out_m = _mm256_mul_ps(grad_norm, _mm256_mul_ps(scale_m, std_m));

                    _mm256_stream_ps(&grad_out_ptr[j], grad_out_m);
                    _mm256_store_ps(&center_grad_[channel_ptr], grad_beta);
                    _mm256_store_ps(&scale_grad_[channel_ptr], grad_gamma);
                }
                inp_ptr      += skip_size;
                grad_ptr     += skip_size;
                act_ptr      += skip_size;
                grad_out_ptr += skip_size;
            }
        }
    protected:

        // Pass-through: this layer's real work happens in fwd()/bwd().
        TensorMap<T, 4> apply_impl(TensorMap<T, 4> input, void* device_ptr = nullptr, bool release_input=true) override
        {
            return input;
        }

        TensorMap<T, 4> fwd_verify_impl(TensorMap<T, 4> input, float** aux_data, int linear_idx, void* device_ptr = nullptr, bool release_input = true) override
        {
            return input;
        }

        array4d output_shape_;
        int output_size_ = 0;
        float moving_mean_ = 0.0f;
        float moving_std_ = 0.0f;
        // Non-owning views into the *_t_ Tensors below.
        float* center_ = nullptr;
        float* scale_ = nullptr;
        float* center_grad_ = nullptr;
        float* scale_grad_ = nullptr;
        // Defaults used until update_params() is called.
        bool privacy_ = false;
        float eps_ = 0.0f;
        float momentum_ = 0.0f;
        // Statistics and sizes from the last fwd(), reused by bwd(). The
        // defaults make a bwd() before any fwd() a no-op instead of dividing
        // by zero.
        float saved_mean = 0.0f;
        float saved_std = 1.0f;
        int batch_size_ = 0;
        int mini_batch_size_ = 1;
        std::string mode_;
        __m256 (*skip_func) (__m256, __m256, __m256, float, float, float) = skip_pass;
        Tensor<float, 1> center_t_;
        Tensor<float, 1> scale_t_;
        Tensor<float, 1> center_grad_t_;
        Tensor<float, 1> scale_grad_t_;
    };
}

#endif
