#ifndef _RESNET_ACTIVATION_SP_H_
#define _RESNET_ACTIVATION_SP_H_

#include <assert.h>

#include <math.h>
#include <numeric>
#include <string>
#include "example_utils.hpp"
#include "Enclave_t.h"
#include "layer.hpp"
#include "pool_sp.h"
using namespace dnnl;
using namespace std;
#include "immintrin.h"

// ResnetActivation: a LayerSp used only by DarKnight to combine several
// operators into a single layer.
//
// NOTE(review): judging from the members (eps_/momentum_, mean/std,
// center/scale, bias, act_mode_, pool_ and the get_bn_/get_act_/get_pool_
// flags) this presumably fuses batch-norm, bias, an activation and optional
// pooling -- confirm against the .cpp implementation.
//
// NOTE(review): the class holds raw pointers (e.g. pool_) and declares a
// destructor, but the copy constructor / copy assignment are still the
// implicit member-wise ones. Copying an instance would share those pointers;
// avoid copies (or delete the copy operations once callers are checked).
class ResnetActivation : public LayerSp{
  public:
    /// Construct the combined layer.
    /// @param act_mode    activation mode name.
    /// @param in_size     4-element input size (layout not visible here --
    ///                    TODO confirm, e.g. N/C/H/W).
    /// @param out_size    4-element output size.
    /// @param pool_window 2-element pooling window.
    /// @param pool_stride 2-element pooling stride.
    /// @param eps         epsilon value for the layer.
    /// @param momentum    momentum value for the layer.
    /// @param bias_data   bias buffer (ownership/copy semantics not visible
    ///                    here -- TODO confirm).
    /// @param src_tags    source memory descriptor.
    /// @param eng         dnnl engine.
    /// @param s           dnnl stream.
    ResnetActivation(string act_mode, 
                     int in_size[4], 
                     int out_size[4], 
                     int pool_window[2], 
                     int pool_stride[2], 
                     float eps,
                     float momentum,
                     float* bias_data,
                     memory::desc src_tags, 
                     engine eng, 
                     stream s);

    /// Intentionally a no-op for this layer; the combined pass is
    /// presumably performed by forward_sp().
    void forward(float*, float*, bool) {}

    /// Combined forward pass (defined out of line).
    /// @param src         input buffer.
    /// @param dst         output buffer.
    /// @param mean_extern externally supplied mean buffer (semantics defined
    ///                    in the implementation -- TODO confirm).
    /// @param is_first    whether this is the first call/layer (TODO confirm).
    /// @param final_enc   flag forwarded to the implementation (TODO confirm).
    void forward_sp(float* src, 
                    float* dst, 
                    float* mean_extern, 
                    bool is_first, 
                    bool final_enc);
    void update_backward(memory::desc);
    void backward(float*, float*);

    /// Layer type identifier; always 0 for this layer.
    int type() {return 0;}
    /// Returns the stored scalar output size (out_size_).
    int out_size()   {return this->out_size_;}
    /// Returns the stored scalar input size (in_size_).
    int input_size() {return this->in_size_;}
    
    /// Memory descriptor used for the source gradient (nhwc_).
    memory::desc diff_src_desc() {return  nhwc_;}
    /// Memory descriptor of the layer output (nhwc_out_).
    memory::desc dst_desc() {return nhwc_out_;}
    ~ResnetActivation();

    // Project-defined M_t state exposed publicly; meaning not visible here.
    M_t bm;
    M_t um;
    M_t gm;
    M_t igm;
  private:

    bool diff_dst_reorder = false;
    engine eng_;
    stream s_;
    // Scalars are zero-initialized so they never hold indeterminate values
    // if the constructor does not set every one of them.
    int in_size_ = 0;
    int out_size_ = 0;
    int channel_ = 0;
    float* saved_work_=nullptr;
    // Buffer pointers default to nullptr so the out-of-line destructor (and
    // any lazily-allocating code) can distinguish "never allocated" from a
    // live buffer instead of reading an indeterminate pointer.
    // NOTE(review): *_ptr_ / *_aligned_ pairs are presumably a raw allocation
    // and its aligned view -- confirm in the .cpp.
    float* center_aligned_ = nullptr;
    float* scale_aligned_ = nullptr;
    float* center_grad_aligned_ = nullptr;
    float* scale_grad_aligned_ = nullptr;
    float* center_ptr_ = nullptr;
    float* scale_ptr_ = nullptr;
    float* center_grad_ptr_ = nullptr;
    float* scale_grad_ptr_ = nullptr;
    float* bias_ptr_ = nullptr;
    float* bias_aligned_ = nullptr;
    float* bias_grad_ptr_ = nullptr;
    float* bias_grad_aligned_ = nullptr;
    float* mean_ptr_ = nullptr;
    float* mean_aligned_ = nullptr;
    float* std_ptr_ = nullptr;
    float* std_aligned_ = nullptr;
    float* act_src_ptr_ = nullptr;
    float* act_src_aligned_ = nullptr;
    float* batch_src_ptr_ = nullptr;
    float* batch_src_aligned_ = nullptr;
    float* work_ptr_ = nullptr;
    float* work_align_ = nullptr;
    float eps_ = 0.0f;
    float momentum_ = 0.0f;
    memory::desc src_tag_;
    int internal_batch_size_ = 0;
    int batch_size_ = 0;
    // Which sub-operators this instance applies (set by the implementation).
    bool get_pool_ = false;
    bool get_act_ = false;
    bool get_bn_ = false;
    std::string act_mode_;
    PoolSp* pool_ = nullptr;
    memory::desc nhwc_;
    memory::desc nhwc_out_;
    
}; // class ResnetActivation

#endif
