#ifndef __NeuralNetwork_h__
#define __NeuralNetwork_h__
#include "Common.h"

////Small fully connected network x[0] -> x[1] -> ... -> x[L-1] with analytic derivatives.
////Each transition x[i]->x[i+1] is either a linear layer (x2=A*x1+b) or an elementwise activation.
////All trainable parameters are flattened into a vector t of length nt (per linear layer: A row-major, then b);
////Ct holds dC/dt with C=x[L-1] the network output.
class SimpleNetwork
{
public:
	////layer type codes stored in `type`
	static const int linear=0;		////x2=A*x1+b
	static const int relu=1;		////x2=max(0,x1), elementwise
	static const int leaky_relu=2;	////x2=x1 if x1>0, else leaky_alpha*x1, elementwise

	int L=0;				////number of activation vectors, including input and output
	Array<VectorX> x;		////L, x data
	Array<int> xs;			////L, x size
	Array<int> type;		////L-1, type of the layer mapping x[i] to x[i+1]
	int nl=0;				////number of linear layers
	int nX=0;				////input dimension
	int nC=0;				////output dimension
	int nt=0;				////parameter dimension
	Array<int> n2l;			////L-1, idx mapping from L to nl (-1 for non-linear layers)
	Array<int> l2n;			////nl, idx mapping from nl to L
	Array<MatrixX> A;		////nl, linear transform matrix mapping values from i to i+1
	Array<VectorX> b;		////nl, linear bias
	Array<Vector2i> As;		////nl, A size as (rows,cols)=(xs[i+1],xs[i])
	Array<int> A_ptr;		////nl+1, starting idx of each A in nt
	Array<MatrixX> g;		////L, g[i] stores dC/dx[i] ([nC,xs[i]]); g[L-1] is the identity
	MatrixX Ct;				////[nC,nt]
	real leaky_alpha=(real)0.01;	////negative-side slope used by leaky_relu layers

	////Set up layer sizes/types, allocate all buffers and randomly initialize A and b.
	////_size: L entries, the dimension of each activation x[i].
	////_type: at least L-1 entries; _type[i] is linear, relu or leaky_relu for the layer x[i]->x[i+1].
	////Non-linear layers must preserve dimension (_size[i]==_size[i+1]).
	////On invalid input an error is printed and the network is left unchanged.
	////Calling Initialize again fully re-initializes the network.
	void Initialize(const Array<int>& _size,const Array<int>& _type)
	{
		const int n_layer=(int)_size.size();
		if(n_layer<1||(int)_type.size()<n_layer-1){
			std::cerr<<"SimpleNetwork::Initialize: need at least one layer size and one type per layer transition"<<std::endl;
			return;}
		for(int i=0;i<n_layer-1;i++){
			if(_type[i]<linear||_type[i]>leaky_relu){
				std::cerr<<"SimpleNetwork::Initialize: unknown layer type "<<_type[i]<<" at layer "<<i<<std::endl;
				return;}
			if(_type[i]!=linear&&_size[i]!=_size[i+1]){
				std::cerr<<"SimpleNetwork::Initialize: activation layer "<<i<<" must keep its size ("<<_size[i]<<" vs "<<_size[i+1]<<")"<<std::endl;
				return;}}

		xs=_size;
		type=_type;

		L=n_layer;
		nX=xs[0];
		nC=xs[L-1];

		////build layer <-> linear-layer index maps; l2n is cleared so re-initialization does not append to stale entries
		n2l.resize(L-1);
		l2n.clear();
		for(int i=0;i<L-1;i++){
			if(type[i]==linear){
				l2n.push_back(i);
				n2l[i]=(int)l2n.size()-1;}
			else n2l[i]=-1;}
		nl=(int)l2n.size();

		x.resize(L);
		for(int i=0;i<L;i++){
			x[i].resize(xs[i]);}

		////parameter layout in t, per linear layer: A (row-major, rows*cols entries) followed by b (rows entries)
		A.resize(nl);
		As.resize(nl);
		b.resize(nl);
		A_ptr.resize(nl+1);
		A_ptr[0]=0;
		for(int a_i=0;a_i<nl;a_i++){
			const int i=l2n[a_i];
			const Vector2i s(xs[i+1],xs[i]);
			As[a_i]=s;
			A_ptr[a_i+1]=A_ptr[a_i]+s[0]*s[1]+s[0];
			A[a_i].resize(s[0],s[1]);
			b[a_i].resize(s[0]);}
		nt=A_ptr[nl];

		////g[L-1]=dC/dC is the identity; the other entries are filled by Backward()
		g.resize(L);
		for(int i=0;i<L-1;i++){
			g[i].setZero(nC,xs[i]);}
		g[L-1].setIdentity(nC,nC);

		Ct.setZero(nC,nt);

		////uniform random init in [-1,1)*sqrt(6/rows); rand() is consumed in the same order as before (A row-major, then b)
		////NOTE(review): the scale uses As[l][0] (output dim); He-uniform init would use the fan-in As[l][1]. Kept to preserve existing initial weights.
		for(int l=0;l<nl;l++){
			const double scale=sqrt(6.0/(double)(As[l][0]));
			for(int i=0;i<As[l][0];i++){
				for(int j=0;j<As[l][1];j++){
					A[l](i,j)=(real)(rand()%20000-10000)/(real)10000*scale;}}
			for(int i=0;i<As[l][0];i++){
				b[l](i)=(real)(rand()%20000-10000)/(real)10000*scale;}}
	}

	////Finite-difference estimate of dC/dt, printed as nCt ([nC,nt], same column layout as Ct) so it can be
	////compared against the analytic Ct from dCdt(). Uses central differences. A, b and x hold their
	////unperturbed values on return; Ct is not modified.
	void Check_Numerical_Derivatives()
	{
		const real dx=(real).00001;
		MatrixX nCt=MatrixX::Zero(nC,nt);
		////perturb one parameter by +-dx and store (C(p+dx)-C(p-dx))/(2dx) in column c_idx of nCt
		auto probe=[&](real& p,const int c_idx){
			const real old=p;
			VectorX c_plus,c_minus;
			p=old+dx;Forward();Get_Output(c_plus);
			p=old-dx;Forward();Get_Output(c_minus);
			p=old;
			nCt.col(c_idx)=(c_plus-c_minus)/((real)2*dx);};
		for(int l=0;l<nl;l++){
			for(int i=0;i<As[l][0];i++){
				for(int j=0;j<As[l][1];j++){
					probe(A[l](i,j),A2Ct_Idx(l,i,j));}}
			for(int i=0;i<As[l][0];i++){
				probe(b[l](i),b2Ct_Idx(l,i));}}
		Forward();	////leave x consistent with the unperturbed parameters
		std::cout<<"nCt:\n"<<nCt<<std::endl;
	}

	////x2=A*x1+b
	void Forward_Linear(const VectorX& x1,const MatrixX& A,const VectorX& b,VectorX& x2)
	{
		x2=A*x1+b;
	}

	////x2[i]=max(0,x1[i]); x2 must already have x1's size
	void Forward_ReLU(const VectorX& x1,VectorX& x2)
	{
		for(int i=0;i<x1.size();i++){
			x2[i]=std::max((real)0,x1[i]);}
	}

	////x2[i]=x1[i] if x1[i]>0, else leaky_alpha*x1[i]; x2 must already have x1's size
	void Forward_LeakyReLU(const VectorX& x1,VectorX& x2)
	{
		for(int i=0;i<x1.size();i++){
			x2[i]=x1[i]>(real)0?x1[i]:leaky_alpha*x1[i];}
	}

	////chain rule through a linear layer: dC/dx1=dC/dx2*A
	void Backward_Linear(const MatrixX& input,const MatrixX& A,MatrixX& output)
	{
		output=input*A;
	}

	////chain rule through ReLU: columns of dC/dx2 are zeroed where x1<0 (derivative 1 is used at x1==0)
	void Backward_ReLU(const MatrixX& input,const VectorX& x1,MatrixX& output)
	{
		VectorX D(x1.size());
		for(int i=0;i<x1.size();i++)D[i]=x1[i]<(real)0?(real)0:(real)1;
		output=input*D.asDiagonal();
	}

	////chain rule through leaky ReLU: columns of dC/dx2 are scaled by leaky_alpha where x1<0 (same convention at 0 as Backward_ReLU)
	void Backward_LeakyReLU(const MatrixX& input,const VectorX& x1,MatrixX& output)
	{
		VectorX D(x1.size());
		for(int i=0;i<x1.size();i++)D[i]=x1[i]<(real)0?leaky_alpha:(real)1;
		output=input*D.asDiagonal();
	}

	////update values in x, starting from x[0] (see Set_X)
	void Forward()
	{
		for(int i=0;i<L-1;i++){
			switch(type[i]){
			case linear:Forward_Linear(x[i],A[n2l[i]],b[n2l[i]],x[i+1]);break;
			case relu:Forward_ReLU(x[i],x[i+1]);break;
			case leaky_relu:Forward_LeakyReLU(x[i],x[i+1]);break;}}
	}

	////update g from g[L-1]=I back to g[0]; uses the activations x from the latest Forward()
	void Backward()
	{
		for(int i=L-2;i>=0;i--){
			switch(type[i]){
			case linear:Backward_Linear(g[i+1],A[n2l[i]],g[i]);break;
			case relu:Backward_ReLU(g[i+1],x[i],g[i]);break;
			case leaky_relu:Backward_LeakyReLU(g[i+1],x[i],g[i]);break;}}
	}

	////index in t (and column in Ct) of A[l](i,j); A is stored row-major
	int A2Ct_Idx(const int l,const int i,const int j)
	{
		return A_ptr[l]+i*As[l][1]+j;
	}

	////index in t (and column in Ct) of b[l](i); biases follow the layer's A entries
	int b2Ct_Idx(const int l,const int i)
	{
		return A_ptr[l]+As[l][0]*As[l][1]+i;
	}

	////update Ct, Ct has dim (nc,nt); runs Forward() and Backward() first
	void dCdt()
	{
		Forward();
		Backward();

		for(int l=nl-1;l>=0;l--){
			const MatrixX& Cx2=g[l2n[l]+1];	////dC/dx2, [nC,rows]
			const VectorX& x1=x[l2n[l]];		////layer input, [cols]
			const int rows=As[l][0];
			const int cols=As[l][1];
			////dC/dA(i,j)=dC/dx2_i*x1_j; row i of A occupies `cols` consecutive columns of Ct
			for(int i=0;i<rows;i++){
				Ct.middleCols(A2Ct_Idx(l,i,0),cols)=Cx2.col(i)*x1.transpose();}
			////dC/db_i=dC/dx2_i; the biases occupy `rows` consecutive columns of Ct
			Ct.middleCols(b2Ct_Idx(l,0),rows)=Cx2;}
	}

	////update g, g[0] has dim (nc,xs[0])
	void dCdX()
	{
		Forward();
		Backward();
	}

	////copy nX input values into x[0]
	void Set_X(const real* x1)
	{
		for(int i=0;i<nX;i++)x[0][i]=x1[i];
	}

	////load all parameters from t (nt values, layout given by A2Ct_Idx/b2Ct_Idx)
	void Set_t(const real* t)
	{
		for(int l=0;l<nl;l++){
			for(int i=0;i<As[l][0];i++){
				for(int j=0;j<As[l][1];j++){
					A[l](i,j)=t[A2Ct_Idx(l,i,j)];}}
			for(int i=0;i<As[l][0];i++){
				b[l](i)=t[b2Ct_Idx(l,i)];}}
	}

	////store all parameters into t (nt values, layout given by A2Ct_Idx/b2Ct_Idx)
	void Get_t(real* t)
	{
		for(int l=0;l<nl;l++){
			for(int i=0;i<As[l][0];i++){
				for(int j=0;j<As[l][1];j++){
					t[A2Ct_Idx(l,i,j)]=A[l](i,j);}}
			for(int i=0;i<As[l][0];i++){
				t[b2Ct_Idx(l,i)]=b[l](i);}}
	}

	////copy of the last activation x[L-1]
	void Get_Output(VectorX& output)
	{output=x[L-1];}

	void Print_X()
	{
		std::cout<<"kernel x: ";
		for(int i=0;i<nX;i++){std::cout<<x[0][i]<<", ";}
		std::cout<<std::endl;
	}

	////print parameters in t order (per layer: A row-major, then b)
	void Print_t()
	{
		std::cout<<"kernel t: ";
		for(int l=0;l<nl;l++){
			for(int i=0;i<As[l][0];i++){
				for(int j=0;j<As[l][1];j++){
					std::cout<<A[l](i,j)<<", ";}}
			for(int i=0;i<As[l][0];i++){
				std::cout<<b[l](i)<<", ";}}
		std::cout<<std::endl;
	}
};
#endif