#pragma once
#include <vector>
#include <fstream>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <openacc.h>
#include "event.h"
#include "tensor.h"

class Module;  // forward declaration: lets the typedef below name Module* before Module is defined

// A network model: an ordered sequence of Module pointers.
// A typedef (rather than a #define) obeys normal scoping/name lookup instead of textually
// rewriting every `Model` token in every file that includes this header; at every use site
// it names exactly the same type, std::vector<Module*>.
// NOTE(review): this type does not express who owns (deletes) the pointed-to modules — confirm
// with the code that builds and tears down a Model.
typedef std::vector<Module*> Model;

/// Abstract base class for every layer stored in a Model.
///
/// The only pure-virtual hook is event_arrive(); every other virtual member has a
/// base-class definition that lives outside this header. Declaration order is kept
/// stable on purpose so the vtable layout is unchanged.
class Module {
public:
		/// Process one vector of events; overridable, defined out of line.
		virtual std::vector<Event> forward(const std::vector<Event> &unordered_event_vec);

		/// Batched variant: one vector of events per entry of the outer vector.
		virtual std::vector<std::vector<Event> > forward(const std::vector<std::vector<Event> > &unordered_event_vec);

		/// Mandatory hook: every concrete layer must say how it reacts to arriving events.
		virtual std::vector<Event> event_arrive(const std::vector<Event> &unordered_event_vec) = 0;

		/// Load parameters from the file at `path`; base version is defined out of line.
		virtual void load_weight(const char * path);

		/// Produce a copy for distributed execution; not necessary for non-distributed Module.
		virtual Module* distributed_clone();

		/// Virtual so that deleting a derived layer through a Module* is well-defined.
		virtual ~Module()
		{
		}
};

/// Flattening layer configured with three dimensions C, H, W.
/// NOTE(review): presumably maps (C, H, W) event coordinates onto a single flat index —
/// event_arrive() and distributed_clone() are defined elsewhere; confirm.
class Flatten2d : public Module {
	public :
		int C, H, W;  // presumably channels, height, width of the incoming feature map

		// Zero-initialise so a default-constructed instance (e.g. one filled in by a clone)
		// never carries indeterminate sizes; reading those would be undefined behaviour.
		Flatten2d() : C(0), H(0), W(0) {}
		Flatten2d(int C, int H, int W) : C(C), H(H), W(W) {}

		// Overrides Module::event_arrive.
		virtual std::vector<Event> event_arrive(const std::vector<Event> &unordered_event_vec);
		// Overrides Module::distributed_clone.
		virtual Module* distributed_clone();
};

/// Dropout layer parameterised by p_dropout.
/// NOTE(review): presumably p_dropout is the probability of discarding each event —
/// event_arrive() is defined elsewhere; confirm.
class DropOut : public Module {
	public :
		double p_dropout;

		// Initialise explicitly so a default-constructed instance has a determinate value
		// (0.0, i.e. the parameter is "unset") instead of garbage.
		DropOut() : p_dropout(0.0) {}
		DropOut(double p_dropout) : p_dropout(p_dropout) {}

		// Overrides Module::event_arrive.
		virtual std::vector<Event> event_arrive(const std::vector<Event> &unordered_event_vec);
		// Overrides Module::distributed_clone.
		virtual Module* distributed_clone();
};

/// 2-D sum-pooling layer parameterised by k_size.
/// NOTE(review): presumably k_size is the (square) pooling window edge length —
/// event_arrive() is defined elsewhere; confirm.
class SumPool2d : public Module {
	public :
		int k_size;

		// Initialise explicitly so a default-constructed instance has a determinate value
		// instead of an indeterminate one (reading it would be undefined behaviour).
		SumPool2d() : k_size(0) {}
		SumPool2d(int k_size) : k_size(k_size) {}

		// Overrides Module::event_arrive.
		virtual std::vector<Event> event_arrive(const std::vector<Event> &unordered_event_vec);
		// Overrides Module::distributed_clone.
		virtual Module* distributed_clone();
};

class TimeShift : public Module {
	public :
		float_T mean, stddev;
		TimeShift() {}
		TimeShift(float_T mean, float_T stddev) : mean(mean), stddev(stddev) {}
		std::vector<Event> event_arrive(const std::vector<Event> &unordered_event_vec);
		std::vector<Event> forward(const std::vector<Event> &unordered_event_vec);
		std::vector<std::vector<Event> > forward(const std::vector<std::vector<Event> > &unordered_event_vec);
		Module* distributed_clone();
};

class SpikingConv2d : public Module {
	public :
		int in_C, out_C, k_size, padding, out_H, out_W, stride;
		int avg_pool_size;
		float_T **** w_rev, *** u;
		float_T Vth, min_v_mem;
		bool soft_reset;

		SpikingConv2d() {}
		SpikingConv2d(int in_C, int out_C, int k_size, int padding, int out_H, int out_W, float_T Vth=1.0, bool soft_reset=0, int avg_pool_size=1, int stride=1, float_T min_v_mem=-1.0);
		~SpikingConv2d();
		
		void load_weight(const char * path, float_T div=1.0);
	    std::vector<Event > event_arrive(const std::vector<Event > &vec);
		Module* distributed_clone();
};

class SpikingLinear : public Module {
	public :
		int in_N, out_N;
		float_T ** w, * u, Vth, min_v_mem;
		bool soft_reset;

		SpikingLinear() {}
		SpikingLinear(int in_N, int out_N, float_T Vth=1.0, bool soft_reset=0, float_T min_v_mem=-1.0);

		void load_weight(const char * path, float_T weight_scaling=1.0);
	    std::vector<Event > event_arrive(const std::vector<Event > &vec);
	    Module* distributed_clone();
};

/// Base class for strategies that turn a network's output events into a class label.
///
/// The base implementation is a no-op placeholder: it always reports -1 and "clones"
/// by handing back the same object.
class ReadOut {
public:
		/// Default classifier: ignores both arguments and returns -1.
		virtual int to_class(Model & /*net*/, const std::vector<Event > & /*vec*/)
		{
			return -1;
		}

		/// Default clone shares this very instance; not necessary for non-distributed Module.
		virtual ReadOut* distributed_clone()
		{
			return this;
		}

		/// Virtual so that deleting a derived read-out through a ReadOut* is well-defined.
		virtual ~ReadOut()
		{
		}
};

/// Read-out holding an out_N-sized buffer pointer, last_accumulative_membrane.
/// NOTE(review): presumably picks the class whose accumulated membrane potential is
/// largest — constructor, destructor, to_class and distributed_clone are defined elsewhere.
/// NOTE(review): the destructor is user-declared and the class holds a raw pointer, yet
/// copying still uses the implicit member-wise copy; if the destructor frees the buffer,
/// copying an instance double-frees it — confirm against distributed_clone().
class MembraneArgmax : public ReadOut {
public:
		float_T * last_accumulative_membrane;
		int out_N;

		MembraneArgmax(int out_N);
		~MembraneArgmax();

		// Overrides ReadOut::to_class.
		virtual int to_class(Model & net, const std::vector<Event > &vec);
		// Overrides ReadOut::distributed_clone.
		virtual ReadOut* distributed_clone();
};

/// Stateless read-out.
/// NOTE(review): presumably chooses the class whose output neuron emitted the most spikes —
/// to_class and distributed_clone are defined elsewhere; confirm.
class SpikeCountArgmax : public ReadOut {
public:
		// Holds no resources; the empty body only keeps the destructor user-declared as before.
		~SpikeCountArgmax()
		{
		}

		// Overrides ReadOut::to_class.
		virtual int to_class(Model & net, const std::vector<Event > &vec);
		// Overrides ReadOut::distributed_clone.
		virtual ReadOut* distributed_clone();
};