#ifndef _CONV2D_H_
#define _CONV2D_H_

#include <assert.h>

#include <math.h>
#include <numeric>
#include <string>

#include "Enclave_t.h"
#include "layer.hpp"
// NOTE: this directive sits in a header, so every file that includes it gets
// all dnnl names (memory, engine, stream, primitive, ...) unqualified.
using namespace dnnl;
// Timing helpers defined outside this header. Conv2d::forward/backward take a
// reading with get_time_force() before and after their work and add
// get_elapsed_time(start, end) to a running counter. The time unit is not
// visible from this file.
double get_time_force();
double get_elapsed_time(double start, double end);

// Multiplies together every extent in `dims` and returns the result.
// An empty list yields 1 (the multiplicative identity).
static memory::dim product(const memory::dims &dims) {
    memory::dim total = 1;
    for (const memory::dim extent : dims) {
        total = total * extent;
    }
    return total;
}

/**
 * 2-D convolution layer implemented with oneDNN primitives.
 *
 * The constructor builds the forward primitive chain (optional source/weight
 * reorders followed by the convolution). update_backward() builds the
 * backward chain (data gradient unless this is the first layer, then weight
 * and bias gradients). forward()/backward() execute those chains in the order
 * in which they were queued.
 *
 * Weight, bias and gradient buffers are obtained through ocall_extern_alloc().
 *
 * NOTE(review): the class owns raw buffers and has a user-defined destructor
 * but keeps the implicit copy operations; copying an instance would lead to a
 * double release. Copy operations were left untouched because callers are not
 * visible from this file.
 */
class Conv2d : public LayerSp{
  public:

    /**
     * Sets up buffers, memory descriptors and the forward primitive chain.
     *
     * @param conv_src_sz      source extents; entries are rearranged as
     *                         [0],[3],[1],[2] to form the oneDNN logical dims
     *                         (looks like NHWC -> NCHW; confirm with callers).
     * @param conv_dst_sz      destination extents, rearranged like the source.
     * @param conv_weight_sz   kernel extents, rearranged as [3],[2],[0],[1];
     *                         the user buffer is described with tag::hwio and
     *                         entry [3] is also the bias length.
     * @param conv_strides_sz  strides (two entries).
     * @param conv_padding_sz  padding (two entries), used for both the left
     *                         and the right padding.
     * @param src_tags         descriptor of the source buffer handed to
     *                         forward().
     * @param eng              oneDNN engine used for all memories/primitives.
     * @param s                oneDNN stream the primitives execute on.
     * @param weight_data      initial weights, copied if non-null.
     * @param bias_data        initial bias, copied if non-null.
     * @param training         stored in is_train_.
     * @param is_first         when true, no data-gradient primitive is built.
     */
    Conv2d(int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
           int conv_strides_sz[2], int conv_padding_sz[2], memory::desc src_tags,
           engine eng, stream s, float* weight_data, float* bias_data, bool training=true, bool is_first=false
          ) {
        // namespaces
        using tag = memory::format_tag;
        using dt = memory::data_type;

        // logical dimensions handed to oneDNN
        memory::dims conv_src_tz     = {conv_src_sz[0], conv_src_sz[3], conv_src_sz[1], conv_src_sz[2]};
        memory::dims conv_weights_tz = {conv_weight_sz[3], conv_weight_sz[2], conv_weight_sz[0], conv_weight_sz[1]};
        memory::dims conv_dst_tz     = {conv_dst_sz[0], conv_dst_sz[3], conv_dst_sz[1], conv_dst_sz[2]};
        memory::dims conv_bias_tz    = {conv_weight_sz[3]};
        memory::dims conv_strides    = {conv_strides_sz[0], conv_strides_sz[1]};
        memory::dims conv_padding    = {conv_padding_sz[0], conv_padding_sz[1]};
        memory::desc weights_tag     = {{conv_weights_tz}, dt::f32, tag::hwio};
        memory::desc src_tag         = src_tags;
        printf("conv src %d, %d, %d, %d", conv_src_sz[0], conv_src_sz[1], conv_src_sz[2], conv_src_sz[3]);
        printf("conv dst %d, %d, %d, %d", conv_dst_sz[0], conv_dst_sz[1], conv_dst_sz[2], conv_dst_sz[3]);
        printf("conv ker %d, %d, %d, %d", conv_weight_sz[0], conv_weight_sz[1], conv_weight_sz[2], conv_weight_sz[3]);
        printf("conv stride %d, %d", conv_strides_sz[0], conv_strides_sz[1]);
        printf("conv padding %d, %d", conv_padding_sz[0], conv_padding_sz[1]);

        // remember the geometry; update_backward() needs it again
        conv_src_tz_                 = conv_src_tz;
        conv_bias_tz_                = conv_bias_tz;
        conv_weights_tz_             = conv_weights_tz;
        conv_dst_tz_                 = conv_dst_tz;
        conv_padding_                = conv_padding;
        conv_strides_                = conv_strides;
        usr_src_desc_                = src_tag;
        usr_weight_desc_             = weights_tag;


        int weight_size = static_cast<int>(product(conv_weights_tz));
        output_size_    = static_cast<int>(product(conv_dst_tz));
        input_size_     = static_cast<int>(product(conv_src_tz));

        // The buffers hold floats, so the element size is sizeof(float)
        // (the previous sizeof(float*) over-allocated on 64-bit targets).
        ocall_extern_alloc((void**) &weight_data_,
                                   weight_size*sizeof(float));
        ocall_extern_alloc((void**) &weight_diff_,
                                   weight_size*sizeof(float));

        ocall_extern_alloc((void**) &bias_data_,
                                   conv_weight_sz[3]*sizeof(float));
        ocall_extern_alloc((void**) &bias_diff_,
                                   conv_weight_sz[3]*sizeof(float));

        // NOTE(review): get_size() is already a byte count, so this reserves
        // sizeof(float) times more than the source occupies. Kept as is because
        // the requirements of dump_src_input() are not visible here.
        ocall_extern_alloc((void**) &src_dump_, sizeof(float)*usr_src_desc_.get_size());

        if (weight_data != nullptr)
            std::copy(weight_data, weight_data + weight_size, weight_data_);

        if (bias_data != nullptr)
            std::copy(bias_data, bias_data + conv_weight_sz[3], bias_data_);
        printf("%f %f\n", weight_data_[0], bias_data_[0]);

        // user-layout views over the weight and bias buffers
        conv_user_weights_memory_ = memory(weights_tag, eng, (void*) weight_data_);
        conv_user_bias_memory_    = memory({{conv_bias_tz}, dt::f32, tag::x}, eng, (void*) bias_data_);
        // descriptors with tag::any let the primitive choose its own layouts
        auto conv_src_md          = memory::desc({conv_src_tz}, dt::f32, tag::any);
        auto conv_bias_md         = memory::desc({conv_bias_tz}, dt::f32, tag::any);
        auto conv_weights_md      = memory::desc({conv_weights_tz}, dt::f32, tag::any);
        auto conv_dst_md          = memory::desc({conv_dst_tz}, dt::f32, tag::any);

        // convolution primitive descriptor
        auto conv_desc = convolution_forward::desc(prop_kind::forward,
                algorithm::convolution_direct, conv_src_md, conv_weights_md,
                conv_bias_md, conv_dst_md, conv_strides, conv_padding,
                conv_padding);

        conv_pd_ = convolution_forward::primitive_desc(conv_desc, eng);
        src_reorder_ = conv_pd_.src_desc() != src_tag;
        weights_reorder_ = conv_pd_.weights_desc() != weights_tag;
        primitive_attr dummy_attr;
        // forward chain order: [src reorder] [weights reorder] convolution
        if (src_reorder_) {
          auto reorder_pd = reorder::primitive_desc(eng, src_tag, eng, conv_pd_.src_desc(), dummy_attr);
          net_fwd.push_back(reorder(reorder_pd));
        }

        if ( weights_reorder_) {
          auto reorder_pd = reorder::primitive_desc(eng, weights_tag, eng, conv_pd_.weights_desc(), dummy_attr);
          conv_weights_memory = memory(conv_pd_.weights_desc(), eng);
          net_fwd.push_back(reorder(reorder_pd));
        }

        net_fwd.push_back(convolution_forward(conv_pd_));
        eng_ = eng;
        s_ = s;
        is_first_ = is_first;
        is_train_ = training;
    }

    /**
     * Releases the weight/bias buffers.
     *
     * NOTE(review): these buffers come from ocall_extern_alloc(), not from
     * new[]; releasing them with delete[] looks mismatched, and src_dump_ is
     * not released at all. The matching release routine is not visible in this
     * file, so the behavior is left unchanged - please confirm.
     */
    ~Conv2d() {
        delete [] weight_data_;
        delete [] bias_data_;
        delete [] bias_diff_;
        delete [] weight_diff_;
    }

    /**
     * Runs the forward convolution.
     *
     * @param src       source buffer laid out as described by usr_src_desc_.
     * @param dst       destination buffer laid out as conv_pd_.dst_desc().
     * @param training  unused.
     *
     * The source is remembered for backward(): with `sharding` set it is
     * handed to dump_src_input(), otherwise only the pointer is kept.
     */
    void forward(float* src, float* dst, bool training=false) {
        double s = get_time_force();
        int i = 0;
        memory conv_user_src_memory = memory(usr_src_desc_, eng_, (void*) src);
        memory conv_dst_memory = memory(conv_pd_.dst_desc(), eng_, (void*) dst);

        // Save the source for the weight-gradient pass. `src` is in the user
        // layout, so its byte size is usr_src_desc_.get_size(); the layout the
        // primitive picked may be padded and therefore larger than `src`.
        // (assumes the last argument of dump_src_input is a byte count - the
        // original call already passed a get_size() value)
        if (sharding) {
            this->dump_src_input((void*) src_dump_,
                        src,
                        usr_src_desc_.get_size());
        } else {
            this->saved_src_ = src;
        }

        memory conv_src_memory;

        // reorder the inputs into the layouts the primitive expects
        if (src_reorder_) {
          conv_src_memory = memory(conv_pd_.src_desc(), eng_);
          net_fwd.at(i).execute(s_,
                                {{DNNL_ARG_FROM, conv_user_src_memory},
                                    {DNNL_ARG_TO, conv_src_memory}});
            i++;
        } else {
          conv_src_memory = conv_user_src_memory;
        }
        if (weights_reorder_) {
          net_fwd.at(i).execute(s_,
                    {{DNNL_ARG_FROM, conv_user_weights_memory_},
                   {DNNL_ARG_TO, conv_weights_memory}});
          i++;
        } else {
          conv_weights_memory = conv_user_weights_memory_;
        }
        // actual convolution
        net_fwd.at(i).execute(s_, {{DNNL_ARG_SRC, conv_src_memory},
                {DNNL_ARG_WEIGHTS, conv_weights_memory},
                {DNNL_ARG_BIAS, conv_user_bias_memory_},
                  {DNNL_ARG_DST, conv_dst_memory}});

        double e = get_time_force();
        // NOTE(review): the elapsed time is added to the batch-norm counter;
        // a convolution-specific counter may have been intended.
        batchnorm_fwd_time += get_elapsed_time(s, e);

    }

    /** Layout of the data gradient produced by backward(); meaningful only after update_backward() ran with is_first_ == false. */
    memory::desc diff_src_desc() override {
        return conv_bwd_data_pd.diff_src_desc();
    }

    /**
     * Builds the backward primitive chain.
     *
     * @param diff_dst_tag  descriptor of the incoming output gradient that
     *                      backward() will receive.
     *
     * Chain order (backward() walks it with a running index):
     *   unless is_first_: [diff_dst reorder] [weights reorder] backward-data
     *   then:             [src reorder] [diff_dst reorder] backward-weights
     *                     [diff-weights reorder to the user layout]
     */
    void update_backward(memory::desc diff_dst_tag) {
        using tag = memory::format_tag;
        using dt  = memory::data_type;


        diff_dst_tag_ = diff_dst_tag;

        // backward memory descriptors (tag::any: the primitive picks the layout)
        auto conv_bwd_src_md = memory::desc({conv_src_tz_}, dt::f32, tag::any);
        auto conv_diff_bias_md = memory::desc({conv_bias_tz_}, dt::f32, tag::any);
        auto conv_diff_weights_md = memory::desc({conv_weights_tz_}, dt::f32, tag::any);
        auto conv_diff_dst_md = memory::desc({conv_dst_tz_}, dt::f32, tag::any);
        auto conv_diff_src_md = memory::desc({conv_src_tz_}, dt::f32, tag::any);
        // The values below are 64-bit dims / size_t; cast them so that they
        // match the %d conversions.
        printf("conv update_backward 0 %d", static_cast<int>(diff_dst_tag.get_size()));
        printf("src %d %d %d %d", static_cast<int>(conv_src_tz_[0]), static_cast<int>(conv_src_tz_[1]),
               static_cast<int>(conv_src_tz_[2]), static_cast<int>(conv_src_tz_[3]));
        printf("dst %d %d %d %d", static_cast<int>(conv_dst_tz_[0]), static_cast<int>(conv_dst_tz_[1]),
               static_cast<int>(conv_dst_tz_[2]), static_cast<int>(conv_dst_tz_[3]));
        printf("wei %d %d %d %d", static_cast<int>(conv_weights_tz_[0]), static_cast<int>(conv_weights_tz_[1]),
               static_cast<int>(conv_weights_tz_[2]), static_cast<int>(conv_weights_tz_[3]));
        printf("str %d %d", static_cast<int>(conv_strides_[0]), static_cast<int>(conv_strides_[1]));
        printf("pad %d %d", static_cast<int>(conv_padding_[0]), static_cast<int>(conv_padding_[1]));
        // backwards to input (skipped for the first layer)
        if (!is_first_) {
            // creating backward data descriptors
            auto conv_bwd_data_desc = convolution_backward_data::desc(algorithm::convolution_direct,
                                                                    conv_diff_src_md,
                                                                    conv_diff_weights_md,
                                                                    conv_diff_dst_md,
                                                                    conv_strides_,
                                                                    conv_padding_, conv_padding_
                                                                    );
            conv_bwd_data_pd = convolution_backward_data::primitive_desc(conv_bwd_data_desc,
                                                                              eng_,
                                                                              conv_pd_);


            if (diff_dst_tag != conv_bwd_data_pd.diff_dst_desc()) {

                auto reorder_pd = reorder::primitive_desc(eng_,
                                                          diff_dst_tag,
                                                          eng_,
                                                          conv_bwd_data_pd.diff_dst_desc()
                                                          );
                bwd_dst_data_reorder_ = true;
                net_bwd.push_back(reorder(reorder_pd));
            }

            if (usr_weight_desc_ != conv_bwd_data_pd.weights_desc()) {
                auto reorder_pd = reorder::primitive_desc(eng_,
                                                          usr_weight_desc_,
                                                          eng_,
                                                          conv_bwd_data_pd.weights_desc()
                                                          );
                conv_bwd_weights_memory = memory(conv_bwd_data_pd.weights_desc(),
                                                 eng_
                                                 );
                net_bwd.push_back(reorder(reorder_pd));
                bwd_weights_reorder_ = true;
            }
            printf("conv update_backward 1");

            to_back = conv_bwd_data_pd.diff_src_desc();
            net_bwd.push_back(convolution_backward_data(conv_bwd_data_pd));

        }


        printf("conv update_backward 2");

        // create backward convolution primitive descriptor

        auto conv_bwd_weights_desc
          = convolution_backward_weights::desc(algorithm::convolution_direct,
                                               conv_bwd_src_md, conv_diff_weights_md, conv_diff_bias_md,
                                               conv_diff_dst_md, conv_strides_, conv_padding_, conv_padding_);
        conv_bwd_weights_pd =
          convolution_backward_weights::primitive_desc(conv_bwd_weights_desc, eng_, conv_pd_);
        primitive_attr dummy_attr;

        printf("conv update_backward 3");

        // Source reorder for the weight gradient. backward() feeds this step
        // with the saved source in the *user* layout (usr_src_desc_), so the
        // reorder has to start from that layout rather than from the layout
        // the forward primitive used internally.
        if (conv_bwd_weights_pd.src_desc() != usr_src_desc_) {
          auto reorder_pd = reorder::primitive_desc(eng_, usr_src_desc_,
                                                    eng_, conv_bwd_weights_pd.src_desc(), dummy_attr);

          net_bwd.push_back(reorder(reorder_pd));
          bwd_src_reorder_ = true;
        }

        printf("conv update_backward 4");

        if (diff_dst_tag != conv_bwd_weights_pd.diff_dst_desc()) {
          auto reorder_pd = reorder::primitive_desc(eng_, diff_dst_tag,
                                                   eng_, conv_bwd_weights_pd.diff_dst_desc(), dummy_attr);

          net_bwd.push_back(reorder(reorder_pd));
          bwd_dst_reorder_ = true;
        }
        printf("conv update_backward 5");
        auto tmp = convolution_backward_weights(conv_bwd_weights_pd);
        printf("conv update_backward 6");

        net_bwd.push_back(tmp);

        // staging memory in the layout the weight-gradient primitive writes
        conv_diff_weights_memory = memory(conv_bwd_weights_pd.diff_weights_desc(), eng_);
        printf("conv update_backward 7");

        // reorder from the primitive's gradient layout back to the user layout
        if (usr_weight_desc_ != conv_bwd_weights_pd.diff_weights_desc()) {
          auto reorder_pd = reorder::primitive_desc(eng_, conv_bwd_weights_pd.diff_weights_desc(),
                                                    eng_, usr_weight_desc_, dummy_attr);

          net_bwd.push_back(reorder(reorder_pd));
          bwd_gradients_reorder_ = true;
        }

    }

    /** Numeric layer-type identifier; the meaning of 32 is defined outside this file. */
    int type() {return 32;}


    /**
     * Executes the chain built by update_backward().
     *
     * @param conv_bwd_data_ptr  receives the data gradient (layout
     *                           conv_bwd_data_pd.diff_src_desc()); not touched
     *                           when is_first_ is set.
     * @param diff_dst           incoming output gradient laid out as the
     *                           descriptor given to update_backward().
     *
     * Weight and bias gradients end up in weight_diff_ / bias_diff_ in the
     * user layouts (usr_weight_desc_ and tag::x).
     */
    void backward(float* conv_bwd_data_ptr, float* diff_dst) {
        double s = get_time_force();

        using tag = memory::format_tag;
        using dt = memory::data_type;
        int i = 0;
        memory conv_user_diff_dst       = memory(diff_dst_tag_, eng_, (void*) diff_dst);

        memory conv_bwd_data_memory;
        if (!is_first_) {
            conv_bwd_data_memory = memory(conv_bwd_data_pd.diff_src_desc(),
                                                 eng_,
                                                 (void*) conv_bwd_data_ptr

                                             );
        }
        // user-layout views over the gradient buffers
        memory user_diff_weights_memory = memory(usr_weight_desc_, eng_, (void*)weight_diff_);
        memory conv_diff_bias_memory    = memory({{conv_bias_tz_}, dt::f32, tag::x}, eng_, (void*) bias_diff_);
        if (!is_first_) {
            memory conv_diff_dst_memory;
            if (bwd_dst_data_reorder_) {
                conv_diff_dst_memory = memory(conv_bwd_data_pd.diff_dst_desc(), eng_);
                net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_user_diff_dst},
                                           {DNNL_ARG_TO,   conv_diff_dst_memory}});

                i++;
            } else {

                conv_diff_dst_memory = conv_user_diff_dst;
            }

            if (bwd_weights_reorder_) {
                net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_user_weights_memory_},
                    {DNNL_ARG_TO, conv_bwd_weights_memory}}) ;
                i++;
            } else {
                conv_bwd_weights_memory = conv_user_weights_memory_;
            }

            // gradient with respect to the input
            net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
              {DNNL_ARG_WEIGHTS, conv_bwd_weights_memory},
                {DNNL_ARG_DIFF_SRC, conv_bwd_data_memory}
            });
            i++;
        }

        // With sharding the saved source is read back into a scratch buffer.
        // The extra 8 floats (32 bytes) leave room for ALIGN32 to move the
        // start of the buffer forward.
        float* src_ptr = nullptr;
       if (sharding) {
            src_ptr        = new float[usr_src_desc_.get_size() / sizeof(float) + 8];
            this->saved_src_ = ALIGN32(src_ptr);
            // Load exactly the user-layout byte size: that is what forward()
            // dumped and what the scratch buffer was sized for.
            this->load_src_input((void*) this->saved_src_, (void*) this->src_dump_,
                                                    usr_src_desc_.get_size());
       }
        memory conv_src_memory = memory(usr_src_desc_, eng_, this->saved_src_);
        memory conv_bwd_src_memory;
        if (bwd_src_reorder_) {
            conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng_);
            net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_src_memory},
                {DNNL_ARG_TO, conv_bwd_src_memory}});
          i++;
        } else {
            conv_bwd_src_memory = conv_src_memory;
        }

        memory conv_diff_dst_memory;

        if (bwd_dst_reorder_) {
            conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng_);
            net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_user_diff_dst},
                {DNNL_ARG_TO, conv_diff_dst_memory}});
          i++;
        } else {
          conv_diff_dst_memory = conv_user_diff_dst;
        }

        // The primitive must write in its own gradient layout. When that
        // differs from the user layout, write into the staging memory created
        // by update_backward() and convert afterwards; otherwise write straight
        // into weight_diff_.
        memory &diff_weights_target = bwd_gradients_reorder_ ? conv_diff_weights_memory
                                                             : user_diff_weights_memory;

        // actual gradients
        net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
              {DNNL_ARG_SRC, conv_bwd_src_memory},
                    {DNNL_ARG_DIFF_WEIGHTS, diff_weights_target},
                  {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
        i++;

        // Run the diff-weights reorder queued by update_backward() so that
        // weight_diff_ receives the gradient in the user layout.
        if (bwd_gradients_reorder_) {
            net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_diff_weights_memory},
                {DNNL_ARG_TO, user_diff_weights_memory}});
            i++;
        }

        // Release the sharding scratch buffer (it used to leak on every call).
        // Wait for the stream first so that nothing still reads from it.
        if (src_ptr != nullptr) {
            s_.wait();
            delete [] src_ptr;
            this->saved_src_ = nullptr;
        }

        double e = get_time_force();
        // NOTE(review): the elapsed time is added to the batch-norm counter;
        // a convolution-specific counter may have been intended.
        batchnorm_bwd_time += get_elapsed_time(s, e);


    }

    /** Number of elements in the destination tensor. */
    int out_size() {
        return output_size_;
    }

    /** Number of elements in the source tensor. */
    int input_size() {
        return input_size_;
    }

    /** Layout of the buffer forward() writes its result to. */
    memory::desc dst_desc() {
        return conv_pd_.dst_desc();
    }


    // Buffers obtained through ocall_extern_alloc() in the constructor.
    float* weight_data_;   // weights, user layout (tag::hwio)
    float* bias_data_;     // bias
    float* bias_diff_;     // bias gradient written by backward()
    float* weight_diff_;   // weight gradient written by backward(), user layout
    float* src_dump_;      // target of dump_src_input() when sharding is on

    // Primitive chains, executed front to back by forward()/backward().
    std::vector<primitive> net_fwd;
    std::vector<primitive> net_bwd;
    engine eng_;
    stream s_;
    int output_size_;      // element count of the destination tensor
    int input_size_;       // element count of the source tensor
    convolution_forward::primitive_desc          conv_pd_;
    convolution_backward_weights::primitive_desc conv_bwd_weights_pd;
    convolution_backward_data::primitive_desc    conv_bwd_data_pd;
    bool is_first_ = false;               // skip the data gradient
    bool is_train_ = false;
    // Each flag records that the matching reorder was queued in a chain.
    bool src_reorder_ = false;            // forward: user src -> primitive src
    bool weights_reorder_ = false;        // forward: user weights -> primitive weights
    bool bwd_src_reorder_ = false;        // backward-weights: saved src
    bool bwd_dst_reorder_ = false;        // backward-weights: diff_dst
    bool bwd_gradients_reorder_ = false;  // backward-weights: diff weights -> user layout
    bool bwd_weights_reorder_ = false;    // backward-data: weights
    bool bwd_dst_data_reorder_ = false;   // backward-data: diff_dst
    memory conv_weights_memory;           // weights in the forward primitive's layout


    memory conv_bwd_weights_memory;       // weights in the backward-data layout
    memory conv_diff_weights_memory;      // weight gradient in the primitive's layout
    memory::desc to_back;                 // diff_src layout of the backward-data primitive

    memory::desc usr_src_desc_;           // layout of the src passed to forward()
    memory::desc usr_weight_desc_;        // layout of weight_data_/weight_diff_
    memory::desc diff_dst_tag_;           // layout of the diff_dst passed to backward()

    memory conv_user_weights_memory_;     // view over weight_data_
    memory conv_user_bias_memory_;        // view over bias_data_

    memory::dims conv_src_tz_ ;
    memory::dims conv_bias_tz_ ;
    memory::dims conv_weights_tz_;
    memory::dims conv_dst_tz_;
    memory::dims conv_padding_;
    memory::dims conv_strides_;

};
#endif