#include "nn.h"

std::vector<std::vector<Event> > Module::forward(const std::vector<std::vector<Event> > &unordered_event_vec) {
    std::vector<std::vector<Event> > ret;
    std::vector<Event > x;
    for (const auto &vec : unordered_event_vec) {
        x = event_arrive(vec);
        if (!x.empty())
            ret.push_back(x);
    }
    return ret;
}

std::vector<Event> Module::forward(const std::vector<Event> &unordered_event_vec) {
    std::vector<Event> ret;
    std::vector<Event > x;
    for (const auto &e : unordered_event_vec) {
        x = event_arrive(std::vector<Event>{e});
        if (!x.empty())
            ret.insert(ret.end(), x.begin(), x.end());
    }
    return ret;
}

// Default weight loader for modules that have no weights: does nothing.
// The path argument is accepted only to match the common signature.
void Module::load_weight(const char * path) {
    (void) path; // intentionally unused
}

// Base-class version: produces no copy and returns a null pointer. The
// concrete module types defined further down in this file supply versions
// that return a heap-allocated copy instead.
Module* Module::distributed_clone() {
    return nullptr;
}

std::vector<Event> Flatten2d::event_arrive(const std::vector<Event> &vec) {
    std::vector<Event> ret;
    for (auto e : vec) {
        ret.push_back(Event(e.c * H * W + e.x * W + e.y, e.ts, e.A));
    }
    return ret;
}

// Returns a new heap-allocated Flatten2d that is default-constructed and
// then copy-assigned from this object. Nothing in this function releases
// it; presumably the caller takes ownership (confirm against callers).
Module* Flatten2d::distributed_clone() {
    Flatten2d * const copy = new Flatten2d();
    *copy = *this;
    return copy;
}

std::vector<Event> DropOut::event_arrive(const std::vector<Event> &vec) {
    std::vector<Event> ret;
    for (auto e : vec) {
        if (randu() > p_dropout)
            ret.push_back(e);
    }
    return ret;
}

// Returns a new heap-allocated DropOut that is default-constructed and then
// copy-assigned from this object. Nothing in this function releases it;
// presumably the caller takes ownership (confirm against callers).
Module* DropOut::distributed_clone() {
    DropOut * const copy = new DropOut();
    *copy = *this;
    return copy;
}

std::vector<Event> SumPool2d::event_arrive(const std::vector<Event> &vec) {
    std::vector<Event> ret;
    for (auto e : vec) {
        ret.push_back(Event(e.c, e.x / k_size, e.y / k_size, e.ts));
    }
    return ret;
}

// Returns a new heap-allocated SumPool2d that is default-constructed and
// then copy-assigned from this object. Nothing in this function releases
// it; presumably the caller takes ownership (confirm against callers).
Module* SumPool2d::distributed_clone() {
    SumPool2d * const copy = new SumPool2d();
    *copy = *this;
    return copy;
}

std::vector<Event> TimeShift::event_arrive(const std::vector<Event> &unordered_event_vec) {
    std::vector<Event> ret;
    for (auto e : unordered_event_vec) {
        e.ts += randn(mean, stddev);
        ret.push_back(e);
    }
    return ret;
}

std::vector<Event> TimeShift::forward(const std::vector<Event> &unordered_event_vec) {
    std::vector<Event> ret;
    for (auto e : unordered_event_vec) {
        e.ts += randn(mean, stddev);
        ret.push_back(e);
    }
    std::sort(ret.begin(), ret.end());
    return ret;
}

// Batch overload: hands back an unmodified copy of the input.
// NOTE(review): unlike the single-vector overload, no jitter or sorting is
// applied here - confirm that the pass-through is intentional.
std::vector<std::vector<Event> > TimeShift::forward(const std::vector<std::vector<Event> > &unordered_event_vec) {
    std::vector<std::vector<Event> > untouched(unordered_event_vec);
    return untouched;
}

// Returns a new heap-allocated TimeShift that is default-constructed and
// then copy-assigned from this object. Nothing in this function releases
// it; presumably the caller takes ownership (confirm against callers).
Module* TimeShift::distributed_clone() {
    TimeShift * const copy = new TimeShift();
    *copy = *this;
    return copy;
}

// Stores the layer configuration and allocates the two tensors the layer
// works on, both filled with 0:
//   w_rev - kernel storage, indexed [in_C][out_C][k_size][k_size]
//   u     - membrane state, indexed [out_C][out_H][out_W]
// Weights stay zero until load_weight() is called.
SpikingConv2d::SpikingConv2d(int in_C, int out_C, int k_size, int padding, int out_H, int out_W, float_T Vth, bool soft_reset, int avg_pool_size, int stride, float_T min_v_mem)
    : in_C(in_C),
      out_C(out_C),
      k_size(k_size),
      padding(padding),
      out_H(out_H),
      out_W(out_W),
      Vth(Vth),
      soft_reset(soft_reset),
      avg_pool_size(avg_pool_size),
      stride(stride),
      min_v_mem(min_v_mem) {
    alloc_tensor_4d(w_rev, in_C, out_C, k_size, k_size, 0);
    alloc_tensor_3d(u, out_C, out_H, out_W, 0);
}

// Releases the kernel tensor and the membrane tensor, skipping whichever
// pointer is null.
// NOTE(review): distributed_clone() copy-assigns the whole object, which
// looks like it leaves the clone pointing at the same w_rev; destroying both
// objects would then free it twice - confirm the intended ownership.
SpikingConv2d::~SpikingConv2d() {
    if (w_rev != nullptr) free_tensor_4d(w_rev, in_C, out_C, k_size, k_size);
    if (u != nullptr) free_tensor_3d(u, out_C, out_H, out_W);
}

void SpikingConv2d::load_weight(const char * path, float_T div) {
    
    std::ifstream w_file(path);

    if (!w_file.is_open()) {
        std::cerr << "Failed to open weight file at " << std::string(path) << "\n";
        exit(0);
    }
    
    for (int i = 0; i < out_C; ++i) {
        for (int j = 0; j < in_C; ++j) {
            for (int k = 0; k < k_size; ++k) {
                for (int l = 0; l < k_size; ++l) {
                    w_file >> w_rev[j][i][k_size - k - 1][k_size - l - 1];
                    ((w_rev[j][i][k_size - k - 1][k_size - l - 1] /= avg_pool_size) /= avg_pool_size) /= div;
                }
            }
        }
    }
}

std::vector<Event > SpikingConv2d::event_arrive(const std::vector<Event > &vec) {
    std::vector<Event > ret;
    
    int u_min_x, u_max_x, u_min_y, u_max_y;
    int wx_min, wy_min;

    for (auto e : vec) {
        e.x /= avg_pool_size;
        e.y /= avg_pool_size;

        u_max_x = e.x + padding;
        u_max_y = e.y + padding;
        u_min_x = u_max_x - k_size + 1;
        u_min_y = u_max_y - k_size + 1;
    
        wx_min = 0;
        wy_min = 0;

        if (u_min_x < 0) wx_min = -u_min_x, u_min_x = 0;
        if (u_max_x >= out_H * stride) u_max_x = out_H * stride - 1;
        if (u_min_y < 0) wy_min = -u_min_y, u_min_y = 0;
        if (u_max_y >= out_W * stride) u_max_y = out_W * stride - 1;

        // New membrane update with stride
        
        int x_shift = ((u_min_x + stride - 1) / stride) * stride - u_min_x;
        int y_shift = ((u_min_y + stride - 1) / stride) * stride - u_min_y;
        

        for (int i = 0; i < out_C; ++i) {
            for (int j = (u_min_x + x_shift) / stride, _j = wx_min + x_shift; j <= u_max_x / stride; ++j, _j += stride) {
                // if (j % stride != 0) continue;
                for (int k = (u_min_y + y_shift) / stride, _k = wy_min + y_shift; k <= u_max_y / stride; ++k, _k += stride) {
                    // if (k % stride != 0) continue;
                    u[i][j][k] += w_rev[e.c][i][_j][_k];
                }
            }
        }
    }
    
    for (auto e : vec) {
        e.x /= avg_pool_size;
        e.y /= avg_pool_size;

        u_max_x = e.x + padding;
        u_max_y = e.y + padding;
        u_min_x = u_max_x - k_size + 1;
        u_min_y = u_max_y - k_size + 1;
    
        wx_min = 0;
        wy_min = 0;

        if (u_min_x < 0) wx_min = -u_min_x, u_min_x = 0;
        if (u_max_x >= out_H * stride) u_max_x = out_H * stride - 1;
        if (u_min_y < 0) wy_min = -u_min_y, u_min_y = 0;
        if (u_max_y >= out_W * stride) u_max_y = out_W * stride - 1;

        int x_shift = ((u_min_x + stride - 1) / stride) * stride - u_min_x;
        int y_shift = ((u_min_y + stride - 1) / stride) * stride - u_min_y;
        for (int i = 0; i < out_C; ++i) {
            for (int j = (u_min_x + x_shift) / stride, _j = wx_min + x_shift; j <= u_max_x / stride; ++j, _j += stride) {
                // if (j % stride != 0) continue;
                for (int k = (u_min_y + y_shift) / stride, _k = wy_min + y_shift; k <= u_max_y / stride; ++k, _k += stride) {
                    // if (k % stride != 0) continue;
                    float_T & u_neuron = u[i][j][k];
                    
                    if (u_neuron >= Vth) {
                        float_T num = u_neuron / Vth;
                        ret.push_back(Event(i, j, k, e.ts));
                        for (int _ = 1; _ <= num; ++_) {
                            // ret.push_back(Event(i, j, k, e.ts));
                            // ret.push_back(Event(i, j / stride, k / stride, e.ts));
                            u_neuron = soft_reset ? u_neuron - Vth : 0;
                        }
                    }
                    if (u_neuron < min_v_mem) {
                        u_neuron = min_v_mem;
                    }
                }
            }
        }
    }
    
    return ret;
}

void SpikingLinear::load_weight(const char * path, float_T weight_scaling) {
			
    std::ifstream w_file(path);
    
    w = (double **) malloc (in_N * sizeof(double *));
    for (int i = 0; i < in_N; ++i) {
        w[i] = (double *) malloc (out_N * sizeof(double));
        for (int j = 0; j < out_N; ++j) {
            w[i][j] = 0;
        }
    }
    
    for (int i = 0; i < out_N; ++i)
        for (int j = 0; j < in_N; ++j)
            w_file >> w[j][i], w[j][i] /= weight_scaling;

}

// Returns a new heap-allocated SpikingConv2d: default-constructed,
// copy-assigned from this object, and then given its own freshly allocated
// membrane tensor u holding a copy of this object's current membrane values.
// NOTE(review): w_rev is not re-allocated here, so the clone appears to
// share this object's weight tensor; the destructor frees w_rev, so
// destroying both objects would free it twice - confirm ownership.
Module* SpikingConv2d::distributed_clone() {
    SpikingConv2d * const replica = new SpikingConv2d();
    *replica = *this;
    alloc_tensor_3d(replica->u, out_C, out_H, out_W, 0);
    clone_tensor_3d(replica->u, out_C, out_H, out_W, this->u);
    return replica;
}

// Stores the layer configuration and allocates zero-filled storage:
//   w - in_N rows of out_N weights each, indexed w[in][out]
//   u - out_N membrane values
// Allocation results are not checked.
SpikingLinear::SpikingLinear(int in_N, int out_N, float_T Vth, bool soft_reset, float_T min_v_mem)
    : in_N(in_N),
      out_N(out_N),
      Vth(Vth),
      soft_reset(soft_reset),
      min_v_mem(min_v_mem) {

    const auto row_bytes = out_N * sizeof(float_T);

    w = (float_T **) malloc(in_N * sizeof(float_T *));
    for (int row = 0; row < in_N; ++row) {
        w[row] = (float_T *) malloc(row_bytes);
        memset(w[row], 0, row_bytes); // start with all-zero weights
    }

    u = (float_T *) malloc(row_bytes);
    memset(u, 0, row_bytes);
}


std::vector<Event > SpikingLinear::event_arrive(const std::vector<Event > &vec) {
    
    std::vector<Event > ret;
    if (vec.empty())
        return ret;

    int ts = 0;

    for (auto e : vec) {
        for (int i = 0; i < out_N; ++i) {
            u[i] += w[e.x][i];
        }
    }

    ts = vec[0].ts;

    for (int i = 0; i < out_N; ++i) {
        float_T & u_neuron = u[i];
        
        if (u_neuron >= Vth) {
            float_T num = u_neuron / Vth;
            ret.push_back(Event(i, ts, 1));
            for (int _ = 1; _ <= num; ++_) {
                // ret.push_back(Event(i, ts, 1));
                u_neuron = soft_reset ? u_neuron - Vth : 0;
            }
        }
        if (u_neuron < min_v_mem) 
            u_neuron = min_v_mem;
    }
    return ret;
}

// Returns a new heap-allocated SpikingLinear: default-constructed,
// copy-assigned from this object, and then given its own malloc'ed membrane
// buffer u holding a copy of this object's current membrane values.
// NOTE(review): w is not re-allocated here, so the clone appears to share
// this object's weight matrix - confirm that sharing is intended.
Module* SpikingLinear::distributed_clone() {
    SpikingLinear * const replica = new SpikingLinear();
    *replica = *this;
    const auto state_bytes = out_N * sizeof(float_T);
    replica->u = (float_T *) malloc(state_bytes);
    memcpy(replica->u, u, state_bytes);
    return replica;
}

// Records the number of output classes and allocates a zero-filled buffer
// of out_N values in last_accumulative_membrane.
MembraneArgmax::MembraneArgmax(int out_N) : out_N(out_N) {
    const auto buffer_bytes = sizeof(float_T) * (out_N);
    last_accumulative_membrane = (float_T*) malloc(buffer_bytes);
    memset(last_accumulative_membrane, 0, sizeof(float_T) * out_N);
}

// Releases the buffer that the constructor allocated with malloc.
MembraneArgmax::~MembraneArgmax() {
    free(last_accumulative_membrane);
}

// Picks the predicted class as the index of the largest membrane value in
// the model's last layer, then clears that layer's membrane values.
//
//   net - model whose last layer (net.back()) must be a SpikingLinear
//   vec - output events; not used by this read-out
//
// Returns the index of the first maximum among the first out_N membrane
// values. Returns -1 when no value exceeds the initial sentinel (e.g.
// out_N <= 0) or when the last layer is not a SpikingLinear; in the latter
// case an error is written to stderr and nothing is modified.
//
// Side effect: zeroes the first out_N entries of the last layer's u.
// last_accumulative_membrane is not consulted; the raw membrane values are
// compared directly.
int MembraneArgmax::to_class(Model & net, const std::vector<Event > &vec) {
    (void) vec; // this read-out only looks at membrane state

    SpikingLinear * const last_layer = dynamic_cast<SpikingLinear*>(net.back());
    if (last_layer == NULL) {
        // dynamic_cast yields null for any other layer type; dereferencing
        // it would be undefined behaviour.
        std::cerr << "MembraneArgmax: last layer of the model is not a SpikingLinear\n";
        return -1;
    }

    float_T best_value = -1e300;
    int best_index = -1;

    for (int i = 0; i < out_N; ++i) {
        const float_T candidate = last_layer->u[i];
        if (candidate > best_value) {
            best_value = candidate;
            best_index = i;
        }
    }

    memset(last_layer->u, 0, sizeof(float_T) * out_N);
    return best_index;
}

// Returns a new heap-allocated MembraneArgmax built for the same out_N.
// The new object gets its own zeroed buffer from the constructor; the
// contents of this object's buffer are not copied.
ReadOut* MembraneArgmax::distributed_clone() {
    return new MembraneArgmax(out_N);
}

// Derives the class from the output events alone by returning
// event_argmax1d(vec); the model argument is not used.
// NOTE(review): presumably event_argmax1d returns the index with the most
// events - confirm against its definition.
int SpikeCountArgmax::to_class(Model & net, const std::vector<Event > &vec) {
    return event_argmax1d(vec);
}

// Returns a new heap-allocated, default-constructed SpikeCountArgmax.
// No state is copied from this object.
ReadOut* SpikeCountArgmax::distributed_clone() {
    return new SpikeCountArgmax();
}