#ifndef __NeuralPDE_h__
#define __NeuralPDE_h__
#include "OptimizerMma.h"
#include "OptimizerIpOpt.h"
#include "NeuralNetwork.h"
#include "RandomNumber.h"
#include "TypeFunc.h"
#include "Constants.h"
#include "Particles.h"
#include "Mesh.h"
#include "Grid.h"

////Global configuration of the experiment.
////Grid: DIMENSION-d, AXIS_SIZE nodes per axis, spacing DX over a domain of length DOMAIN_LENGTH.
const int DIMENSION=2;
const int AXIS_SIZE = 49;
const real DOMAIN_LENGTH=(real)1;
const real DX=DOMAIN_LENGTH/(real)(AXIS_SIZE-1);
const real ONE_OVER_DX=(real)1/(DX);
const real ONE_OVER_DX_2=(real)1/(DX*DX);
const real IPOPT_TOL=(real)1e-10;

const int DATA_SAMPLE_SIZE=8;				////number of PDE samples held by the optimizer

////Mutable globals are declared inline (C++17) so that including this header from more than one
////translation unit yields a single definition instead of duplicate symbols.
inline bool USE_RANDOM_SAMPLE =false;
const int SAMPLE_FRACTION=32;				////number of random samples is n/SAMPLE_FRACTION

const bool USE_KERNEL_X=true;				////whether the kernel values depend on the previous iteration values
const bool USE_KERNEL_P=true;				////whether the kernel values depend on position
const bool USE_KERNEL_SYMMETRY=false;		////whether the kernel is symmetric
const int SOLVER_ITERATION_NUM=2;			////number of stored iterates (layers) per PDE sample
 
//Array<int> NN_SIZE={6,6,6,6,6,3};
inline Array<int> NN_SIZE = { 3,5,5,5,5,3 };
inline Array<int> NN_TYPE = { 0,1,0,1,0};

//Array<int> NN_SIZE={3,10,10,10,10,10,10,3};				////dimension of vectors on each layer
//Array<int> NN_TYPE={0,1,0,1,0,1,0,1,0};					////0-linear, 1-ReLU

////shared random source used by the data-sample generators
inline RandomNumber random;
////Solve the dense system A*x=b for one right-hand side with a full-pivoting LU of A; x is overwritten.
inline void Solve(const MatrixX& A,const VectorX& b,VectorX& x)
{
	const auto lu=A.fullPivLu();
	x=lu.solve(b);
}
////Same as above for several right-hand sides stored as the columns of b.
inline void Solve(const MatrixX& A,const MatrixX& b,MatrixX& x)
{
	const auto lu=A.fullPivLu();
	x=lu.solve(b);
}

template<int d> class PDEExample
{
	////VecD: Vector3 when d==1, Vector5 otherwise (one value per stencil node)
	using VecD=typename If<d==1,Vector3,Vector5>::Type;
	////VecX: Vector3 when d==1, VectorX otherwise
	using VecX=typename If<d==1,Vector3,VectorX>::Type;
	////project macro; presumably declares the VectorD/VectorDi aliases used by the samples below -- TODO confirm
	Typedef_VectorDii(d);
public:
	////AXIS_SIZE entries along each of the d axes
	const Vector<int, d> counts = Vector<int, d>::Ones()* AXIS_SIZE;
	////grid built from counts and spacing DX; the 2D samples use it to map a linear index to a cell and a position
	Grid<d> grid=Grid<d>(counts,DX);
	
	////Entry point used by callers: fills the right-hand side b and the target tar for sample data_idx.
	////It simply forwards to one of the numbered generators below (currently number 11).
	void Build_Data_Sample(VectorX& b,VectorX& tar,const int data_idx)
	{
		this->Build_Data_Sample_11(b,tar,data_idx);
	}

	////1D sample: f=(ax)^2 with a=data_idx+2, lap f=2a^2 (b stores -lap f on interior nodes), x=i*DX
	////b: interior entries hold -2a^2, the two end entries hold the Dirichlet values f(0)=0 and f((n-1)DX).
	////tar: resized to b.size() and filled with f(i*DX)=(i*a*DX)^2.
	void Build_Data_Sample_1(VectorX& b,VectorX& tar,const int data_idx)
	{
		const int n=(int)b.size();
		const real a=(real)(data_idx+2);
		const real interior_rhs=(real)-2.*a*a;

		b[0]=(real)0;
		for(int k=1;k<n-1;k++){b[k]=interior_rhs;}
		b[n-1]=pow((n-1)*a*DX,2);

		tar.resize(n);
		for(int k=0;k<n;k++){tar[k]=(real)pow(k*a*DX,2);}
	}

	////y=sin(ax), lap y=-a^2sin(ax), x=0:31
	////1D sample f=sin(a*i) with a=(data_idx+1)/32, evaluated at integer node positions i.
	////b: interior entries hold a^2 sin(a*i) (minus the Laplacian), the end entries hold f.
	////tar: f plus a perturbation random.Value()*.2-.1 drawn once per node, in index order.
	void Build_Data_Sample_2(VectorX& b,VectorX& tar,const int data_idx)
	{
		const int n=(int)b.size();
		const real a=(real)(data_idx+1)*(real)1/(real)32;

		auto node=[](int i){return (real)i;};
		auto value=[a](real s){return sin(a*s);};
		auto laplacian=[a](real s){return -a*a*sin(a*s);};

		for(int k=1;k<n-1;k++){b[k]=-laplacian(node(k));}
		b[0]=value(node(0));
		b[n-1]=value(node(n-1));

		tar.resize(n);
		for(int k=0;k<n;k++){
			tar[k]=value(node(k))+random.Value()*(real).2-(real).1;}
	}

	////lap x=0, with x0=0, x1=random, grid_size=32;
	////1D Laplace sample: assembles the (-1,2,-1)/DX^2 stencil with Dirichlet rows at both ends,
	////sets b to zero except b[n-1]=random.Value(), and stores the solution of A*tar=b in tar.
	void Build_Data_Sample_3(VectorX& b,VectorX& tar,const int data_idx)
	{
		const int n=(int)b.size();
		const Vector3 stencil=Vector3(-1.,2.,-1.)*ONE_OVER_DX_2;	////identical for every interior row

		MatrixX A(n,n);
		A.setZero();
		for(int row=1;row<n-1;row++){
			for(int k=0;k<3;k++){A(row,row-1+k)=stencil[k];}}
		A(0,0)=(real)1;			////Dirichlet
		A(n-1,n-1)=(real)1;		////Dirichlet

		b.fill((real)0);
		b[0]=(real)0;
		b[n-1]=random.Value();

		Solve(A,b,tar);
	}

	////lap x=0 with varying coef, with x0=random,x1=random
	////\nabla\cdot(1+|\pi p|)\nabla \mx=0 ,grid_size=32;
	////1D variable-coefficient sample: each interior row is (-u1,u1+u2,-u2)/DX^2 where u1,u2 are the
	////coefficient evaluated at the two face midpoints; both ends are Dirichlet with random values.
	////NOTE(review): the position lambda multiplies the index by ONE_OVER_DX (not DX), and the
	////coefficient is |pi*x| without the "1+" of the formula in the comment above -- confirm which is intended.
	void Build_Data_Sample_4(VectorX& b,VectorX& tar,const int data_idx)
	{
		const int n=(int)b.size();

		auto position=[](int i){return (real)i*ONE_OVER_DX;};
		auto coefficient=[](real x){return abs(pi*x*(real)1);};

		MatrixX A(n,n);
		A.setZero();
		for(int row=1;row<n-1;row++){
			const real left=coefficient((real).5*(position(row-1)+position(row)));
			const real right=coefficient((real).5*(position(row)+position(row+1)));
			const Vector3 c=Vector3(-left,left+right,-right)*ONE_OVER_DX_2;
			for(int k=0;k<3;k++){A(row,row-1+k)=c[k];}}
		A(0,0)=(real)1;			////Dirichlet
		A(n-1,n-1)=(real)1;		////Dirichlet

		////the two random draws happen in this order: first node, then last node
		b[0]=random.Value();
		b[n-1]=random.Value();
		for(int k=1;k<n-1;k++){b[k]=0.;}

		Solve(A,b,tar);
	}

	////nonlinear \nabla\cdot(1+|x|+sin(|x|*.001))\nabla x=0 with varying coef, with x0=random,x1=random
	//// grid_size=32;
	////Nonlinear 1D sample solved by fixed-point iteration: the face diffusivity 1+|u|+sin(.001u)
	////is evaluated from the previous iterate, the linear system is solved, and the loop stops once
	////the squared change of the iterate drops below 1e-10 (at most 100 iterations).
	////b: zero in the interior, random Dirichlet values at both ends. tar: the converged iterate.
	////Fixes: the right face used cos() while the left face (and the formula) use sin(), so two
	////adjacent rows disagreed on their shared face; tar is now sized before it is filled.
	void Build_Data_Sample_5(VectorX& b,VectorX& tar,const int data_idx)
	{
		const int n=(int)b.size();
		const int max_iter=100;
		const real tol=(real)1e-10;

		////one diffusivity function for both faces keeps the discretization consistent
		auto diffusivity=[](real u){return (real)(1.+abs(u)+sin(u*.001));};

		if(tar.size()!=n)tar.resize(n);
		tar.fill((real)0);

		b.fill((real)0);
		b[0]=random.Value();
		b[n-1]=random.Value();

		MatrixX A(n,n);
		bool converge=false;
		for(int iter=0;iter<max_iter;iter++){
			A.setZero();
			for(int i=1;i<n-1;i++){
				const real a1=diffusivity((real).5*(tar[i-1]+tar[i]));
				const real a2=diffusivity((real).5*(tar[i]+tar[i+1]));

				const Vector3 coef=Vector3(-a1,a1+a2,-a2)*ONE_OVER_DX_2;
				A(i,i-1)=coef[0];
				A(i,i)=coef[1];
				A(i,i+1)=coef[2];}

			A(0,0)=(real)1;			////Dirichlet
			A(n-1,n-1)=(real)1;		////Dirichlet

			const VectorX tar_old=tar;
			Solve(A,b,tar);
			const VectorX r=tar-tar_old;
			const real res=r.squaredNorm();
			if(res<tol){converge=true;break;}}
		if(!converge)std::cout<<"Build_Data_Sample_5 not converge."<<std::endl;
	}
	////2D poisson equation
	////f=(ax)^2+ay^2,  lap=2a*a+2a, x,y=((0:31)+0.5)/32
	////2D Poisson sample with f=(a*x)^2+a*y^2, a=data_idx+2, evaluated at grid.Center(cell).
	////tar holds f everywhere; b holds f on boundary cells and -(2a^2+2a) on interior cells.
	void Build_Data_Sample_6(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const real a = (real)(data_idx + 2);
		tar.resize(n);
		for (int k = 0; k < n; k++) {
			const VectorDi cell = grid.Cell_Coord(k);
			const VectorD  pos = grid.Center(cell);
			const real f = (real)pow(pos[0] * a, 2) + (real)a * pos[1] * pos[1];
			tar[k] = f;
			if (grid.Is_Boundary_Cell(cell)) { b[k] = f; }
			else { b[k] = (real)-(2 * a * a + 2 * a); }
		}
	}
	////f=(ax)^3+ay^2,  lap=6a*a*a*x+2a, x,y=((0:31)+0.5)/32
	////2D Poisson sample with f=(a*x)^3+a*y^2, a=data_idx+2, evaluated at grid.Center(cell).
	////tar holds f everywhere; b holds f on boundary cells and -(6a^3 x+2a) on interior cells.
	void Build_Data_Sample_7(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const real a = (real)(data_idx + 2);
		tar.resize(n);
		for (int k = 0; k < n; k++) {
			const VectorDi cell = grid.Cell_Coord(k);
			const VectorD  pos = grid.Center(cell);
			const real f = (real)pow(pos[0] * a, 3) + (real)a * pos[1] * pos[1];
			tar[k] = f;
			if (grid.Is_Boundary_Cell(cell)) { b[k] = f; }
			else { b[k] = (real)-(6 * a * a*a*pos[0] + 2 * a); }
		}
	}
	////f=sin(ax+ay),lapf=-2a*a*sin(ax+ay),x,y=((0:31)+0.5)/32 
	////2D Poisson sample with f=sin(a*x+a*y), a=(data_idx+4)/4, evaluated at grid.Center(cell).
	////tar holds f everywhere; b holds f on boundary cells and 2a^2 sin(a*x+a*y) on interior cells.
	void Build_Data_Sample_8(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const real a = (real)(data_idx + 4)/4;
		tar.resize(n);
		for (int k = 0; k < n; k++) {
			const VectorDi cell = grid.Cell_Coord(k);
			const VectorD  pos = grid.Center(cell);
			tar[k] = (real)sin(a * pos[0] + a * pos[1]);
			if (grid.Is_Boundary_Cell(cell)) {
				b[k] = (real)sin(a*pos[0]+a*pos[1]);
			}
			else {
				b[k] = (real)2*a*a*sin(a*pos[0]+a*pos[1]);
			}
		}
	}
	//Helmholtz Equation: lap u+u=0, xdomain(-8,8),ydomain=(-8,8), boundary condition:u(x,y)= -a/(x^2+y^2), grid_size=32*32
	////Helmholtz sample lap u+u=0 on the grid mapped to (-8,8)^2 (pos*16-8).
	////Boundary rows are Dirichlet with u=-a/(x^2+y^2), a=(data_idx+4)/4. Interior rows carry the
	////5-point stencil scaled by h^2 with h=16*DX: 1 on the four neighbours, h^2-4 on the diagonal.
	////b is zero on interior rows; tar receives the solution of A*tar=b.
	////Cleanup: removed the shadowing, unused pos/x/y in the interior branch, the unused neighbour
	////coordinate and the duplicate Nb_C call; the interior b[i]=0 was redundant after b.fill(0).
	void Build_Data_Sample_9(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const real a = (real)(data_idx + 4)/4;
		MatrixX A(n, n);
		A.setZero();
		b.fill((real)0);
		for (int i = 0; i < n; i++) {
			const VectorDi cell = grid.Cell_Coord(i);
			if (grid.Is_Boundary_Cell(cell)) {
				////Dirichlet row; the physical position is only needed here
				const VectorD pos = grid.Center(cell);
				const real x = pos[0]*16-8;
				const real y = pos[1]*16-8;
				A(i, i) = 1;
				b[i] = -a/(x*x+y*y);
			}
			else {
				A(i, i) = 16*16*DX*DX-4;
				for (int c = 0; c < 4; c++) {
					const int nb_index = grid.Cell_Index(grid.Nb_C(cell, c));
					A(i, nb_index) = (real)1;
				}
			}
		}
		Solve(A, b, tar);
	}

	//Helmholtz Equation: lap u+u=0, xdomain(-6,6),ydomain(-6,6), boundary condition:u(x,y)= a*(sin(0.02*x^2+0.02*y^2)),grid_size=32*32
	////Helmholtz sample lap u+u=0 on the grid mapped to (-6,6)^2 (pos*12-6).
	////Boundary rows are Dirichlet with u=a*sin(0.02x^2+0.02y^2), a=(data_idx+12)/8. Interior rows
	////carry the 5-point stencil scaled by h^2 with h=12*DX: 1 on the four neighbours, h^2-4 on the diagonal.
	////b is zero on interior rows; tar receives the solution of A*tar=b.
	////Cleanup: removed the shadowing, unused pos/x/y in the interior branch, the unused neighbour
	////coordinate and the duplicate Nb_C call; the interior b[i]=0 was redundant after b.fill(0).
	void Build_Data_Sample_10(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const real a = (real)(data_idx + 12) /8 ;
		MatrixX A(n, n);
		A.setZero();
		b.fill((real)0);
		for (int i = 0; i < n; i++) {
			const VectorDi cell = grid.Cell_Coord(i);
			if (grid.Is_Boundary_Cell(cell)) {
				////Dirichlet row; the physical position is only needed here
				const VectorD pos = grid.Center(cell);
				const real x = pos[0] * 12 - 6;
				const real y = pos[1] * 12 - 6;
				A(i, i) = 1;
				b[i] = a*sin(0.02*x*x+0.02*y*y);
			}
			else {
				A(i, i) = 144* DX * DX - 4;
				for (int c = 0; c < 4; c++) {
					const int nb_index = grid.Cell_Index(grid.Nb_C(cell, c));
					A(i, nb_index) = (real)1;
				}
			}
		}
		Solve(A, b, tar);
	}
	////time series of wave function wtt=lap w
	////Wave-equation sample w_tt=lap w advanced with the explicit three-level scheme
	////  W^{n+1}=2W^n-W^{n-1}+sigma^2*(sum of the four neighbours of W^n - 4W^n),  sigma=c*dt/DX=1/sqrt(30),
	////for data_idx+1 time steps on an AXIS_SIZE x AXIS_SIZE field with zero boundary values.
	////A point source 2.5*sin(60t) is imposed at cell (AXIS_SIZE/2-1,AXIS_SIZE/2-1) of W^n each step.
	////Output: tar[i]=W^{n+1} at cell i; b[i]=-W^n on interior cells and 0 on boundary cells.
	////Fixes:
	//// - the time levels are shifted for the whole field before the stencil sweep; shifting cell by
	////   cell inside the sweep made the stencil read neighbours from two different time levels
	//// - the number of steps is counted with an integer instead of comparing an accumulated
	////   floating-point time against a*dt, which could run one step too many
	//// - the source is imposed once per step instead of once per inner iteration
	//// - tar is sized before it is written
	void Build_Data_Sample_11(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const int step_num = data_idx + 1;
		const int src = AXIS_SIZE / 2 - 1;						////source cell index along both axes
		MatrixX Wnm1(AXIS_SIZE, AXIS_SIZE);  Wnm1.setZero();	////w at n-1
		MatrixX Wn(AXIS_SIZE, AXIS_SIZE);  Wn.setZero();		////w at n
		MatrixX Wnp1(AXIS_SIZE, AXIS_SIZE);  Wnp1.setZero();	////w at n+1
		b.fill(0);
		const real c = 1;
		const real sigma = (real)1 / sqrt(30);
		const real dt= sigma * (DX / c);
		std::cout << "dt="<<dt << std::endl;
		for (int step = 1; step <= step_num; step++) {
			const real t = step * dt;
			////shift the time levels for the whole field; boundary entries stay zero because
			////the sweep below never writes them
			Wnm1 = Wn;
			Wn = Wnp1;
			Wn(src, src) = (real) 2.5* sin(60 * (t));
			for (int i = 1; i < AXIS_SIZE-1; i++) {
				for (int j = 1; j < AXIS_SIZE-1; j++) {
					Wnp1(i, j) = 2 * Wn(i, j) - Wnm1(i, j) + sigma*sigma*(Wn(i + 1, j) + Wn(i - 1, j) + Wn(i, j - 1) + Wn(i, j + 1) - 4 * Wn(i, j));
				}
			}
		}
		if (tar.size() != n) tar.resize(n);
		for (int i = 0; i < n; i++) {
			const VectorDi cell = grid.Cell_Coord(i);
			tar[i] = Wnp1(cell[0], cell[1]);
			if (grid.Is_Boundary_Cell(cell)) { b[i] = 0; }
			else b[i]= -Wn(cell[0], cell[1]);
		}
	}
	////2D sample with f=sin(a*x+a*y), a=(data_idx+4)/4, evaluated at grid.Center(cell).
	////tar holds f everywhere; b holds f on boundary cells and 2a^2 sin(a*x+a*y) on interior cells.
	////Fix: the loop used to declare a local `real b` that shadowed the output vector b, which made
	////`b[i]=...` ill-formed and left the random value unused. The local has been removed.
	////TODO(review): the removed local was commented "b randomly generated"; if a random second
	////coefficient was intended it still has to be wired into f.
	void Build_Data_Sample_12(VectorX& b, VectorX& tar, const int data_idx)
	{
		const int n = (int)b.size();
		const real a = (real)(data_idx + 4) / 4;
		tar.resize(n);
		for (int i = 0; i < n; i++) {
			const VectorDi cell = grid.Cell_Coord(i);
			const VectorD  pos = grid.Center(cell);
			const real f = (real)sin(a * pos[0] + a * pos[1]);
			if (grid.Is_Boundary_Cell(cell)) { b[i] = f; }
			else { b[i] = (real)2 * a * a * sin(a * pos[0] + a * pos[1]); }
			tar[i] = f;
		}
	}
	//////////////////////////////////////////////////////////////////////////
	////analytical kernel implementations
	////1D kernel functions
	////Analytic 1D kernel: returns (-a1,a1+a2,-a2) with a_k=1+t0*t1*u_k^2, where u1 and u2 are the
	////averages of the left and right pairs of the three stencil values in x1.
	Vector3 C_A(const Vector3& x1,const VectorX& t)
	{
		const real left_avg=(real).5*(x1[0]+x1[1]);
		const real right_avg=(real).5*(x1[1]+x1[2]);
		const real left=(1.+t[0]*t[1]*pow(left_avg,2));
		const real right=(1.+t[0]*t[1]*pow(right_avg,2));
		return Vector3(-left,left+right,-right);
	}

	////calculate dCdti
	////Derivative of the analytic 1D kernel with respect to parameter t[i]:
	////the factor multiplying u^2 is t[1] when i==0 and t[0] otherwise.
	Vector3 dCdt_A(const Vector3& x1,const int i,const VectorX& t)
	{
		const real left_avg=(real).5*(x1[0]+x1[1]);
		const real right_avg=(real).5*(x1[1]+x1[2]);
		const real left=(i==0?t[1]:t[0])*pow(left_avg,2);
		const real right=(i==0?t[1]:t[0])*pow(right_avg,2);
		return Vector3(-left,left+right,-right);
	}

	////Derivative of the analytic 1D kernel with respect to stencil value x1[i], i in 0..2.
	////Any other i yields the zero vector.
	Vector3 dCdx1_A(const Vector3& x1,const VectorX& t,const int i/*0-2*/)
	{
		const real left_avg=(real).5*(x1[0]+x1[1]);
		const real right_avg=(real).5*(x1[1]+x1[2]);
		const real d_left=t[0]*t[1]*left_avg;		////derivative of a1 w.r.t. either of its two inputs
		const real d_right=t[0]*t[1]*right_avg;		////derivative of a2 w.r.t. either of its two inputs

		if(i==0)return Vector3(-d_left,d_left,(real)0);
		if(i==1)return Vector3(-d_left,d_left+d_right,-d_right);
		if(i==2)return Vector3((real)0,d_right,-d_right);
		return Vector3::Zero();
	}
	

	////TODO: 2D kernel functions, may not need this for now
	////2D analytic kernel placeholder: ignores its arguments and returns the zero vector
	Vector5 C_A(const Vector5& x1,const VectorX& t)
	{
		////TODO
		return Vector5::Zero();
	}

	////2D placeholder for dC/dt_i: ignores its arguments and returns the zero vector
	Vector5 dCdt_A(const Vector5& x1,const int i,const VectorX& t)
	{
		////TODO
		return Vector5::Zero();
	}

	////2D placeholder for dC/dx1_i: ignores its arguments and returns the zero vector
	Vector5 dCdx1_A(const Vector5& x1,const VectorX& t,const int i/*0-2*/)
	{
		////TODO
		return Vector5::Zero();
	}
};

template<int d> class NeuralKernel
{public:
	using VecD=typename If<d==1,Vector3,Vector5>::Type;			////type for x
	using VecX=typename If<d==1,Vector3,VectorX>::Type;			////type for p, with dimension dim_of_x*d
	Typedef_VectorDii(d);

	bool use_kernel_p=USE_KERNEL_P;
	bool use_kernel_x=USE_KERNEL_X;
	bool use_kernel_symmetry=USE_KERNEL_SYMMETRY;

	//////////////////////////////////////////////////////////////////////////
	//network kernel implementation
	//the following three functions are 1D implementations only!
	Vector3 C_N(const Vector3& x1,const Vector3& p1,const VectorX& t,SimpleNetwork* nn)
	{
		if(use_kernel_x&&use_kernel_p){Vector6 x(x1[0],x1[1],x1[2],p1[0],p1[1],p1[2]);nn->Set_X(&x[0]);}
		else if(use_kernel_p){Vector6 x((real)0.,(real)0.,(real)0.,p1[0],p1[1],p1[2]);nn->Set_X(&x[0]);}
		else if(use_kernel_x){nn->Set_X(&x1[0]);}

		nn->Set_t(&t[0]);
		nn->Forward();
		Vector3 v;
		if(!use_kernel_symmetry)for(int i=0;i<nn->nC;i++)v[i]=nn->x[nn->L-1][i];
		else{v[0]=v[2]=nn->x[nn->L-1][0];v[1]=nn->x[nn->L-1][1];}	////symmetric kernel
		return v;
	}

	////calculate dCdti
	Vector3 dCdt_N(const Vector3& x1,const Vector3& p1,const int t_idx,const VectorX& t,SimpleNetwork* nn)
	{
		if(use_kernel_x&&use_kernel_p){Vector6 x(x1[0],x1[1],x1[2],p1[0],p1[1],p1[2]);nn->Set_X(&x[0]);}
		else if(use_kernel_p){Vector6 x((real)0.,(real)0.,(real)0.,p1[0],p1[1],p1[2]);nn->Set_X(&x[0]);}
		else if(use_kernel_x){nn->Set_X(&x1[0]);}

		nn->Set_t(&t[0]);
		nn->dCdt();
		Vector3 v;
		if(!use_kernel_symmetry)for(int i=0;i<nn->nC;i++)v[i]=nn->Ct(i,t_idx);
		else{v[0]=v[2]=nn->Ct(0,t_idx);v[1]=nn->Ct(1,t_idx);}	////symmetric kernel
		return v;
	}

	Vector3 dCdx1_N(const Vector3& x1,const Vector3& p1,const VectorX& t,const int t_idx/*0-2*/,SimpleNetwork* nn)
	{
		if(use_kernel_x&&use_kernel_p){Vector6 x(x1[0],x1[1],x1[2],p1[0],p1[1],p1[2]);nn->Set_X(&x[0]);}
		else if(use_kernel_p){Vector6 x((real)0.,(real)0.,(real)0.,p1[0],p1[1],p1[2]);nn->Set_X(&x[0]);}
		else if(use_kernel_x){nn->Set_X(&x1[0]);}

		nn->Set_t(&t[0]);
		nn->dCdX();
		Vector3 v;
		if(!use_kernel_symmetry)for(int i=0;i<nn->nC;i++)v[i]=nn->g[0](i,t_idx);
		else{v[0]=v[2]=nn->g[0](0,t_idx);v[1]=nn->g[0](1,t_idx);}	////symmetric kernel
		return v;
	}	

	////DONE: 2D implementations
	Vector5 C_N(const Vector5& x1,const VectorX& p1,const VectorX& t,SimpleNetwork* nn)
	{
		if(use_kernel_x&&use_kernel_p){
			ArrayF<real,15> x={x1[0],x1[1],x1[2],x1[3],x1[4],
				p1[0],p1[1],p1[2],p1[3],p1[4],p1[5],p1[6],p1[7],p1[8],p1[9]};
			nn->Set_X(&x[0]);}
		else if(use_kernel_p){
			ArrayF<real,15> x={(real)0.,(real)0.,(real)0.,(real)0.,(real)0.,
				p1[0],p1[1],p1[2],p1[3],p1[4],p1[5],p1[6],p1[7],p1[8],p1[9]};
			nn->Set_X(&x[0]);}
		else if(use_kernel_x){nn->Set_X(&x1[0]);}

		nn->Set_t(&t[0]);
		nn->Forward();
		Vector5 v;
		if(!use_kernel_symmetry)for(int i=0;i<nn->nC;i++)v[i]=nn->x[nn->L-1][i];
		else{
			v[0]=v[4]=nn->x[nn->L-1][0];
			v[1]=v[3]=nn->x[nn->L-1][1];
			v[2]=nn->x[nn->L-1][2];
		}	////symmetric kernel
		return v;
	}

	////calculate dCdti
	Vector5 dCdt_N(const Vector5& x1,const VectorX& p1,const int t_idx,const VectorX& t,SimpleNetwork* nn)
	{
		if(use_kernel_x&&use_kernel_p){
			ArrayF<real,15> x={x1[0],x1[1],x1[2],x1[3],x1[4],
				p1[0],p1[1],p1[2],p1[3],p1[4],p1[5],p1[6],p1[7],p1[8],p1[9]};
			nn->Set_X(&x[0]);}
		else if(use_kernel_p){
			ArrayF<real,15> x={(real)0.,(real)0.,(real)0.,(real)0.,(real)0.,
				p1[0],p1[1],p1[2],p1[3],p1[4],p1[5],p1[6],p1[7],p1[8],p1[9]};
			nn->Set_X(&x[0]);}
		else if(use_kernel_x){nn->Set_X(&x1[0]);}

		nn->Set_t(&t[0]);
		nn->dCdt();
		Vector5 v;
		if(!use_kernel_symmetry)for(int i=0;i<nn->nC;i++)v[i]=nn->Ct(i,t_idx);
		else{
			//v[0]=v[2]=nn->Ct(0,t_idx);v[1]=nn->Ct(1,t_idx);
			v[0]=v[4]=nn->Ct(0,t_idx);
			v[1]=v[3]=nn->Ct(1,t_idx);
			v[2]=nn->Ct(2,t_idx);
		}	////symmetric kernel
		return v;
	}

	Vector5 dCdx1_N(const Vector5& x1,const VectorX& p1,const VectorX& t,const int t_idx/*0-2*/,SimpleNetwork* nn)
	{
		if(use_kernel_x&&use_kernel_p){
			ArrayF<real,15> x={x1[0],x1[1],x1[2],x1[3],x1[4],
				p1[0],p1[1],p1[2],p1[3],p1[4],p1[5],p1[6],p1[7],p1[8],p1[9]};
			nn->Set_X(&x[0]);}
		else if(use_kernel_p){
			ArrayF<real,15> x={(real)0.,(real)0.,(real)0.,(real)0.,(real)0.,
				p1[0],p1[1],p1[2],p1[3],p1[4],p1[5],p1[6],p1[7],p1[8],p1[9]};
			nn->Set_X(&x[0]);}
		else if(use_kernel_x){nn->Set_X(&x1[0]);}

		nn->Set_t(&t[0]);
		nn->dCdX();
		Vector5 v;
		if(!use_kernel_symmetry)for(int i=0;i<nn->nC;i++)v[i]=nn->g[0](i,t_idx);
		else{
			v[0]=v[4]=nn->g[0](0,t_idx);
			v[1]=v[3]=nn->g[0](1,t_idx);
			v[2]=nn->g[0](2,t_idx);
		}	////symmetric kernel
		return v;
	}	
};

template<int d> class NeuralPDE  
{public:
	////State of one PDE sample: the grid, the right-hand side and target, the stored iterates and
	////their backward derivatives, plus non-owning pointers to the shared network/kernel/example.
	using VecD=typename If<d==1,Vector3,Vector5>::Type;			////type for x
	using VecX=typename If<d==1,Vector3,VectorX>::Type;			////type for p, with dimension dim_of_x*d
	Typedef_VectorDii(d);

	const Vector<int,d> counts=Vector<int,d>::Ones()*AXIS_SIZE;		////grid size
	const int n=counts.prod();										////number of unknowns
	Grid<d> grid;

	int L=SOLVER_ITERATION_NUM;										////layer number, number of Picard iterations
	VectorX b;														////nx1 vector, rhs
	VectorX tar;													////nx1 vector, target values
	Array<VectorX> x;												////L-size array, each nx1 vector, forward value
	Array<MatrixX> g;												////(L)-size array, each nxn matrix, backward derivatives
	
	////non-owning pointers; when nn is null the analytic kernel of example is used by dCdx1
	SimpleNetwork* nn=nullptr;
	PDEExample<d>* example=nullptr;
	NeuralKernel<d>* kernel=nullptr;

	////random samples
	bool use_random_sample=USE_RANDOM_SAMPLE;
	int sample_fraction=SAMPLE_FRACTION;			////fraction=1/sample_fraction
	Array<int> rs_idx;								////random sample index

	//////////////////////////////////////////////////////////////////////////
	////initialization
	void Initialize()
	{
		grid.Initialize(counts,DX);

		x.resize(L);
		for(int i=0;i<L;i++){
			x[i].resize(n);}
		x[0].fill((real)1);					////initial guess of x
		//for(int i=0;i<n;i++){
		//	x[0][i]=(real)(rand()%20000-10000)/(real)10000*sqrt(6.0/(double)(n));}

		g.resize(L);
		for(int i=0;i<L;i++){
			g[i].resize(n,n);
			g[i].setIdentity();}

		b.resize(n);b.fill((real)0);
		tar.resize(n);tar.fill((real)0);

		////random sample
		if(use_random_sample){
			int n_s=n/sample_fraction;
			for(int i=0;i<n_s;i++){
				int idx=rand()%n;
				rs_idx.push_back(idx);}
			std::cout<<"rs: ";
			for(int i=0;i<rs_idx.size();i++)
				std::cout<<rs_idx[i]<<", ";std::cout<<std::endl;}
	}
	////Replace rs_idx with n/sample_fraction indices drawn as rand()%n (repeats are possible).
	void update_random_sample() {
		const int sample_count = n / sample_fraction;
		rs_idx.clear();
		for (int k = 0; k < sample_count; ++k) {
			rs_idx.push_back(rand() % n);
		}
	}

	//////////////////////////////////////////////////////////////////////////
	////kernel function and its derivatives
	////network kernel implementations
	////Kernel (stencil coefficients) for one node: the network kernel when nn is set, otherwise
	////the analytic kernel of example.
	////Fix: without the fallback the function ran off its end when nn was null, which is
	////undefined behavior for a value-returning function. The fallback mirrors dCdx1.
	VecD C(const VecD& x1,const VecX& p1,const VectorX& t)
	{
		if(nn)return kernel->C_N(x1,p1,t,nn);
		return example->C_A(x1,t);
	}

	////calculate dCdti
	////Derivative of the kernel with respect to parameter t[i]: network version when nn is set,
	////otherwise the analytic version of example.
	////Fix: without the fallback the function ran off its end when nn was null, which is
	////undefined behavior for a value-returning function. The fallback mirrors dCdx1.
	VecD dCdt(const VecD& x1,const VecX& p1,const int i,const VectorX& t)
	{
		if(nn)return kernel->dCdt_N(x1,p1,i,t,nn);
		return example->dCdt_A(x1,i,t);
	}

	////Derivative of the kernel with respect to stencil value i: analytic version of example when
	////no network is attached, network version otherwise.
	VecD dCdx1(const VecD& x1,const VecX& p1,const VectorX& t,const int i/*0-2*/)
	{
		if(!nn){return example->dCdx1_A(x1,t,i);}
		return kernel->dCdx1_N(x1,p1,t,i,nn);
	}

	////assembling A with a dimension-dependent kernel
	////Assemble the system matrix A(x1,t): every interior row receives the kernel C evaluated on
	////that node's stencil values and positions; boundary rows are identity rows (Dirichlet).
	void Build_A(MatrixX& A,const VectorX& x1,const VectorX& t)
	{
		if constexpr (d==1){
			A.setZero();
			for(int row=1;row<n-1;row++){
				const VecD coef=C(Node_Val(row,x1),Node_Pos(row),t);
				Set_Node_Val(row,coef,A);}
			A(0,0)=(real)1;			////Dirichlet
			A(n-1,n-1)=(real)1;		////Dirichlet
		}
		else if constexpr (d==2){
			A.setZero();
			for(int row=0;row<n;row++){
				const VectorDi cell=grid.Cell_Coord(row);
				if(grid.Is_Boundary_Cell(cell)){
					A(row,row)=(real)1;		////Dirichlet
					continue;}
				const VecD stencil_val=Node_Val(cell,x1);
				const VecX stencil_pos=Node_Pos(cell);
				const VecD coef=C(stencil_val,stencil_pos,t);
				Set_Node_Val(cell,coef,A);}
		}
	}

	//////////////////////////////////////////////////////////////////////////
	////Neural network
	void Forward(const VectorX& input,const VectorX& t,VectorX& output)
	{
		MatrixX A(n,n);
		Build_A(A,input,t);
		Solve(A,b,output);
	}

	void Backward(const MatrixX& input,const VectorX& x1,const VectorX& t,MatrixX& output)
	{
		MatrixX x2x1(n,n);
		dx2dx1(x1,b,t,x2x1);
		output=input*x2x1;
	}

	////Sum of squared differences between x and tar over all n entries.
	////Fix: the accumulator was updated by all OpenMP threads without synchronization (data race);
	////it is now an OpenMP reduction variable.
	real Loss_total(const VectorX& x,const VectorX& tar)
	{
		real loss=(real)0;
#pragma omp parallel for reduction(+:loss)
		for(int i=0;i<n;i++){
			const real diff=x[i]-tar[i];
			loss+=diff*diff;}
		return loss;
	}
	////Sum of squared differences between x and tar over the random-sample indices in rs_idx
	////(an index that appears twice is counted twice).
	////Fix: the accumulator was updated by all OpenMP threads without synchronization (data race);
	////it is now an OpenMP reduction variable.
	real Loss_rs(const VectorX& x, const VectorX& tar)
	{
		real loss = (real)0;
		const int sample_count = (int)rs_idx.size();
#pragma omp parallel for reduction(+:loss)
		for (int k = 0; k < sample_count; k++) {
			const int idx = rs_idx[k];
			const real diff = x[idx] - tar[idx];
			loss += diff * diff;
		}
		return loss;
	}

	////Run the forward iterations x[0]->x[1]->...->x[L-1] with parameters t and return the
	////full loss of the last iterate against tar.
	////Fix: the iterations form a chain (step i+1 reads what step i writes) and every step uses
	////the shared network, so the loop must run sequentially; the OpenMP pragma was removed.
	real Loss_total(const VectorX& t)
	{
		for(int i=0;i<L-1;i++){
			Forward(x[i],t,x[i+1]);}
		return Loss_total(x[L - 1], tar);
	}
	////Run the forward iterations x[0]->x[1]->...->x[L-1] with parameters t and return the
	////random-sample loss of the last iterate against tar.
	////Fix: the iterations form a chain (step i+1 reads what step i writes) and every step uses
	////the shared network, so the loop must run sequentially; the OpenMP pragma was removed.
	real Loss_rs(const VectorX& t)
	{
		for (int i = 0; i < L - 1; i++) {
			Forward(x[i], t, x[i + 1]);
		}
		return Loss_rs(x[L - 1], tar);
	}

	//////////////////////////////////////////////////////////////////////////
	////manual derivatives!
	////Gradient of the squared-error loss with respect to x: 2*(x-tar) on every entry, or only on
	////the entries listed in rs_idx (all others zero) when random sampling is enabled.
	////Lx must already have the right size; it is zeroed first.
	////NOTE(review): rs_idx may contain repeated indices; the parallel loop then writes the same
	////value to the same entry from different iterations.
	void dLdx(const VectorX& x,const VectorX& tar,VectorX& Lx)
	{
		Lx.fill((real)0);
		if(use_random_sample){
#pragma omp parallel for 
			for (int k = 0; k < rs_idx.size(); k++) {
				Lx[rs_idx[k]] = (real)2 * (x[rs_idx[k]] - tar[rs_idx[k]]);
			}
		}
		else{
#pragma omp parallel for 
			for(int k=0;k<n;k++){Lx[k]=(real)2*(x[k]-tar[k]);}
		}
	}

	////Gradient of the loss with respect to the kernel parameters t.
	////  1. forward pass: x[i+1] from x[i]
	////  2. backward pass: g[i-1]=g[i]*dx[i]/dx[i-1]
	////  3. accumulate Lt += dLdx^T * g[i] * dx[i]/dt over the layers
	////Lt must already have size t.size(); it is zeroed first.
	////Fix: the accumulation loop ran under "omp parallel for" although every iteration adds to
	////the same Lt and calls dx2dt, which drives the shared network; it now runs sequentially.
	void dLdt(const VectorX& t,VectorX& Lt)
	{
		for(int i=0;i<L-1;i++){Forward(x[i],t,x[i+1]);}
		for(int i=L-1;i>0;i--){Backward(g[i],x[i-1],t,g[i-1]);}

		const int nt=(int)t.size();
		VectorX Lx(n);
		dLdx(x[L-1],tar,Lx);

		Lt.fill((real)0);
		for(int i=L-1;i>0;i--){
			////Layer i-1, from x[i-1] (x1) to x[i] (x2)
			MatrixX x2t(n,nt);
			dx2dt(x[i-1],b,t,x2t);
			Lt+=Lx.transpose()*g[i]*x2t;}
	}

	////A(x,t), calculate dAdt
	void dAdt(const VectorX& x1,const VectorX& t,const int idx,MatrixX& dAdt)
	{
		if constexpr (d==1){
			dAdt.setZero();
#pragma omp parallel for 
			for(int i=1;i<n-1;i++){
				VecD coef=dCdt(VecD(x1[i-1],x1[i],x1[i+1]),Node_Pos(i),idx,t);
				dAdt(i,i-1)=coef[0];
				dAdt(i,i)=coef[1];
				dAdt(i,i+1)=coef[2];}
			dAdt(0,0)=(real)0;
			dAdt(n-1,n-1)=(real)0;}
		else if constexpr (d==2){
			dAdt.setZero();
#pragma omp parallel for 
			for(int i=0;i<n;i++){
				const VectorDi cell=grid.Cell_Coord(i);
				if (!grid.Is_Boundary_Cell(cell)) {
					VecD x = Node_Val(cell, x1);
					VecX p = Node_Pos(cell);
					VecD coef = dCdt(x, p, idx, t);
					//std::cout << "dAdt" << std::endl;
					Set_Node_Val(cell, coef, dAdt);
				}
				else{
					dAdt(i,i) = (real)0;}}
		}
	}

	////A(x1,t), calculate dAdx1
	void dAdx1(const VectorX& x1,const VectorX& t,const int idx,MatrixX& dAdx)
	{
		if constexpr (d==1){
			dAdx.setZero();
			for(int c=-1;c<=1;c++){int i=idx+c;
				if(i<=0||i>=n-1)continue;
				VecD coef=dCdx1(VecD(x1[i-1],x1[i],x1[i+1]),Node_Pos(i),t,1-c);	////convolution needs to reverse local index
				dAdx(i,i-1)=coef[0];
				dAdx(i,i)=coef[1];
				dAdx(i,i+1)=coef[2];}		
		}
		else if constexpr (d==2){
			dAdx.setZero();
			const Vector2i cell=grid.Cell_Coord(idx);
			//std::cout << idx << "=[" << cell[0] << "," << cell[1] << "]" << std::endl;
			for(int c=0;c<5;c++){
				Vector2i nb;
				if (c < 2) nb = grid.Nb_C(cell, c);	////this function returns the nb cell coord following the convention of left-bottom-center-up-right
				if (c == 2) nb = cell;
				if (c > 2) nb = grid.Nb_C(cell, c - 1);	////this function returns the nb cell coord following the convention of left-bottom-center-up-right
				if (grid.Is_Boundary_Cell(cell))continue;
				if(grid.Is_Boundary_Cell(nb))continue;
					VecD x=Node_Val(nb,x1);
					VecX p=Node_Pos(nb);
					VecD coef=dCdx1(x,p,t,4-c);
					//std::cout << "dAdx1" << std::endl;
					Set_Node_Val(nb,coef,dAdx);}
		}
	}

	////A(x1,t)x2=b, calculate dx2dt
	////A(x1,t)x2=b, calculate dx2dt: column i of x2t_m solves A*(dx2/dt_i)=-(dA/dt_i)*x2.
	////Changes:
	//// - the loop over parameters is sequential; it ran under "omp parallel for" although dAdt
	////   drives the single shared network
	//// - A is factorized once and the factorization is reused for both solves
	//// - the n x n scratch matrix is allocated once (dAdt zeroes it on every call)
	//// - an unused local vector was removed
	void dx2dt(const VectorX& x1,const VectorX& b,const VectorX& t,/*dx2dt*/MatrixX& x2t_m)
	{
		MatrixX A(n,n);
		Build_A(A,x1,t);
		const auto lu=A.fullPivLu();
		const VectorX x2=lu.solve(b);

		const int nt=(int)t.size();
		MatrixX b2_m(n,nt);
		MatrixX at(n,n);
		for(int i=0;i<nt;i++){
			dAdt(x1,t,i,at);
			b2_m.col(i)=-at*x2;}
		x2t_m=lu.solve(b2_m);
	}

	////A(x1,t)x2=b, calculate dx2dx1
	////A(x1,t)x2=b, calculate dx2dx1: column j of x2x1 solves A*(dx2/dx1_j)=-(dA/dx1_j)*x2.
	////Changes:
	//// - the loop over unknowns is sequential; it ran under "omp parallel for" although dAdx1
	////   drives the single shared network
	//// - A is factorized once and the factorization is reused for both solves
	//// - the n x n scratch matrix is allocated once instead of once per column (dAdx1 zeroes it on every call)
	void dx2dx1(const VectorX& x1,const VectorX& b,const VectorX& t,/*dx2dt*/MatrixX& x2x1)
	{
		MatrixX A(n,n);
		Build_A(A,x1,t);
		const auto lu=A.fullPivLu();
		const VectorX x2=lu.solve(b);

		MatrixX B2(n,n);
		MatrixX ax(n,n);
		for(int j=0;j<n;j++){
			dAdx1(x1,t,j,ax);
			B2.col(j)=-ax*x2;}

		x2x1=lu.solve(B2);
	}

	////1D auxiliary functions
	////1D stencil values around node i: (x1[i-1],x1[i],x1[i+1]); i must be an interior index.
	Vector3 Node_Val(const int i,const VectorX& x1)
	{
		const real left=x1[i-1];
		const real center=x1[i];
		const real right=x1[i+1];
		return Vector3(left,center,right);
	}

	////Write a 1D stencil into row i of the matrix: columns i-1, i, i+1 receive coef[0..2].
	void Set_Node_Val(const int i,const Vector3& coef,MatrixX& dAdt)
	{
		for(int k=0;k<3;k++){dAdt(i,i-1+k)=coef[k];}
	}

	////1D stencil positions around node i: ((i-1)*DX,i*DX,(i+1)*DX).
	Vector3 Node_Pos(const int i)
	{
		const real left=(real)(i-1)*DX;
		const real center=(real)i*DX;
		const real right=(real)(i+1)*DX;
		return Vector3(left,center,right);
	}

	////2D auxiliary functions
	////2D stencil values around cell, in the order left, bottom, center, up, right.
	////cell must not lie on the grid boundary, otherwise a neighbour coordinate is outside the grid.
	Vector5 Node_Val(const Vector2i& cell,const VectorX& x1)
	{
		const int ci=cell[0],cj=cell[1];
		const Vector2i nodes[5]={
			Vector2i(ci-1,cj),Vector2i(ci,cj-1),Vector2i(ci,cj),Vector2i(ci,cj+1),Vector2i(ci+1,cj)};
		int idx[5];
		for(int k=0;k<5;k++){idx[k]=grid.Cell_Index(nodes[k]);}
		return Vector5 (x1[idx[0]],x1[idx[1]],x1[idx[2]],x1[idx[3]],x1[idx[4]]);
	}

	void Set_Node_Val(const Vector2i& cell,const Vector5& coef,MatrixX& dAdt)
	{
		//std::cout << "[" << cell[0] << "," << cell[1]  << "],coef="<<coef<<std::endl;
		const int c0=grid.Cell_Index(Vector2i(cell[0]-1,cell[1]));
		const int c1=grid.Cell_Index(Vector2i(cell[0],cell[1]-1));
		const int c2=grid.Cell_Index(Vector2i(cell[0],cell[1]));
		const int c3=grid.Cell_Index(Vector2i(cell[0],cell[1]+1));
		const int c4=grid.Cell_Index(Vector2i(cell[0]+1,cell[1]));
		const int i=c2;

		//std::cout << dAdt(i, c0) << "| 0" << std::endl;
		dAdt(i,c0)=coef[0];
		//std::cout << dAdt(i, c0) << "| 0" << std::endl;
		//std::cout << dAdt(i, c1) << "| 1" << std::endl;
		dAdt(i,c1)=coef[1];
		//std::cout << dAdt(i, c1) << "| 1" << std::endl;
		//std::cout << dAdt(i, c2) << "| 2" << std::endl;
		dAdt(i,c2)=coef[2];
		//std::cout << dAdt(i, c2) << "| 2" << std::endl;
		//std::cout << dAdt(i, c3) << "| 3" << std::endl;
		dAdt(i,c3)=coef[3];
		//std::cout << dAdt(i, c3) << "| 3" << std::endl;
		//std::cout << dAdt(i, c4) << "| 4" << std::endl;
		dAdt(i,c4)=coef[4];
		//std::cout << dAdt(i, c4) << "| 4" << std::endl;
	}

	////2D stencil positions around cell as 10 numbers: (x,y) pairs scaled by DX for the left,
	////bottom, center, up and right nodes, in that order.
	VectorX Node_Pos(const Vector2i& cell)
	{
		const int offset[5][2]={{-1,0},{0,-1},{0,0},{0,1},{1,0}};
		VectorX v(10);
		for(int k=0;k<5;k++){
			v[2*k]=DX*(cell[0]+offset[k][0]);
			v[2*k+1]=DX*(cell[1]+offset[k][1]);}
		return v;
	}
};

class NeuralPDEOptimizer 
: public OptimizerIpOpt{using Base=OptimizerIpOpt;
//: public OptimizerMMA{using Base=OptimizerMMA;
public:
	const int data_s=DATA_SAMPLE_SIZE;		////number of training samples; sizes pdes and bounds the per-sample loops
	int nt=2;								////number of optimized parameters; replaced by nn.nt in Initialize_Optimizer when use_network is set
	VectorX t;								////parameter vector being optimized (length nt)
	VectorX Lt;								////gradient of the loss w.r.t. t, filled by dLdt (length nt)
	real lr = 1e-3;							////step size used by Optimize_GD_Without_Ipopt
	Array<NeuralPDE<DIMENSION> > pdes;				////one pde per training sample (resized to data_s in Initialize)

	bool use_network=true;					////when true, every pde gets pointers to nn and kernel, and t is read from nn
	SimpleNetwork nn;						////network built from NN_SIZE/NN_TYPE in Initialize_Network
	NeuralKernel<DIMENSION> kernel;			////kernel object shared (by pointer) with the pdes
	PDEExample<DIMENSION> training_set;		////builds b/tar for the training samples in Initialize
	PDEExample<DIMENSION> testing_set;		////builds b/tar for the samples evaluated in Validation

	bool use_loss_min=true;					////when true, Compute_Objective records the best loss and its parameters
	VectorX t_for_loss_min;					////parameters at which loss_min was observed
	real loss_min=(real)FLT_MAX;			////smallest loss seen by Compute_Objective so far

	std::string output_dir="output";		////root directory used by Write_Output_Files
	Array<VectorX> pred_values;				////per test sample: pde.x[pde.L-1] stored by Validation
	Array<VectorX> true_values;				////per test sample: target vector stored by Validation
	
	void Initialize()
	{
		Initialize_Network();

		////data init for each pde sample
		pdes.resize(data_s);
		for(int i=0;i<pdes.size();i++){
			pdes[i].Initialize();
			if(use_network){
				pdes[i].nn=&nn;
				pdes[i].kernel=&kernel;}
			pdes[i].example=&training_set;
			training_set.Build_Data_Sample(pdes[i].b,pdes[i].tar,i);}

		Initialize_Optimizer();
		////check dLdt
		//dLdt(t,Lt);
		//std::cout<<"dLdt: "<<Lt.transpose()<<std::endl;

		//numerical check
		real loss=Loss(t);
		std::cout << "initial loss=" << loss << std::endl;
		real dt=0.00001;
		//VectorX Lt1(nt);
		//for(int i=0;i<nt;i++){
		//	VectorX t1=t;t1[i]+=dt;
		//	real loss1=Loss(t1);
		//	Lt1[i]=(loss1-loss)/dt;
		//	//std::cout<<"i: "<<i<<", loss1: "<<loss1<<", loss: "<<loss<<std::endl;
		//}
		//std::cout<<"dLdt1: "<<Lt1.transpose()<<std::endl;
	}

	void Initialize_Network()
	{
		//Array<int> size={3,8,8,16,16,8,8,3};
		//Array<int> type={0,1,0,1,0,1,0};
		//Array<int> size={3,4,4,3};
		//Array<int> type={0,1,0};
		//Array<int> size={6,8,8,8,8,3};
		//Array<int> type={0,1,0,1,0};
		//Array<int> size={3,2};
		//Array<int> type={0};
		Array<int> size=NN_SIZE;
		Array<int> type=NN_TYPE;

		nn.Initialize(size,type);
		nn.Print_t();
	}

	////(Re)initialize the parameter/gradient vectors and the base optimizer settings.
	////With use_network set, nt and the initial t come from the network; otherwise t starts at all ones.
	void Initialize_Optimizer()
	{
		if(use_network){nt=nn.nt;}

		Lt.setZero(nt);					////gradient starts at zero
		t.setConstant(nt,(real)1);		////initial guess of t
		if(use_network){nn.Get_t(&t[0]);}

		////base optimizer setup: unconstrained problem with wide box bounds
		n_var=nt;
		n_cons=0;
		var_lb=(real)-1e5;
		var_ub=(real)1e5;
		tol=IPOPT_TOL;
		Allocate_Data();

		////seed the optimizer variables with the current t
		Sync_Var_NN_To_Opt(var);
	}

	//////////////////////////////////////////////////////////////////////////
	////Opt APIs
	////Copy the first n_var optimizer variables into t.
	void Sync_Var_Opt_To_NN(const real* var)
	{
		for(int k=0;k<n_var;k++){t[k]=var[k];}
	}

	////Copy the first n_var entries of t into the optimizer variable buffer.
	void Sync_Var_NN_To_Opt(real* var)
	{
		for(int k=0;k<n_var;k++){var[k]=t[k];}
	}

	////Copy the first n_var entries of the gradient Lt into the optimizer gradient buffer.
	void Sync_Grad_NN_To_Opt(real* grad)
	{
		for(int k=0;k<n_var;k++){grad[k]=Lt[k];}
	}

	////Objective callback: load var into t, evaluate the loss, and print it.
	////When use_loss_min is set and the loss improves on loss_min, the loss and t are recorded
	////and the printed line is tagged with ", min-----".
	virtual real Compute_Objective(const real* var)
	{
		Sync_Var_Opt_To_NN(var);
		const real current=Loss(t);

		const bool is_new_best=use_loss_min&&current<loss_min;
		if(is_new_best){
			loss_min=current;
			t_for_loss_min=t;}

		std::cout<<"Loss: "<<current<<(is_new_best?", min-----":"")<<std::endl;
		return current;
	}

	////Gradient callback: load var into t, evaluate dL/dt into Lt, and hand it to the optimizer.
	virtual void Compute_Gradient(const real* var,real* grad)
	{
		Sync_Var_Opt_To_NN(var);		////t <- var
		dLdt(t,Lt);						////Lt <- summed per-sample gradient
		Sync_Grad_NN_To_Opt(grad);		////grad <- Lt
	}

	////Alternate IpOpt solves and short gradient-descent phases until the full training loss
	////drops to the target (1e-4), then run Validation.
	////Each round: re-initialize the optimizer, run Base::Optimize, read back the result, optionally fall back
	////to the best parameters recorded by Compute_Objective, and run Optimize_GD_Without_Ipopt while the best
	////recorded loss is still above the target.
	////Note: there is no cap on the number of rounds; the loop ends only when the loss reaches the target
	////(or the comparison fails, e.g. for NaN).
	virtual void Optimize()
	{
		const double loss_target = 1e-4;

		USE_RANDOM_SAMPLE = false;
		real loss = Loss(t);
		while (loss > loss_target) {
			Initialize_Optimizer();
			Base::Optimize();
			Sync_Var_Opt_To_NN(intmed_var.data());	////only for IpOpt
			std::cout << "Optimized t: " << t.transpose() << std::endl;
			loss = Loss(t);
			std::cout << "Training loss: " << loss << std::endl;
			std::cout << "tol=" << tol << std::endl;

			////fall back to the best parameters seen during the solve; skip if none were ever recorded
			if (use_loss_min && loss > loss_min && t_for_loss_min.size() == t.size()) {
				t = t_for_loss_min;
				////re-evaluate so the loop condition reflects the restored parameters, not the discarded ones;
				////otherwise a restored t that already meets the target would still trigger another round
				loss = Loss(t);
				std::cout << "Use loss min: " << loss_min << std::endl;
				std::cout << "tol=" << tol<<std::endl;
			}

			if (loss_min > loss_target) {
				Optimize_GD_Without_Ipopt();
				Sync_Var_NN_To_Opt(var);
				USE_RANDOM_SAMPLE = false;	////GD switches to random sampling; restore full-loss evaluation
				loss = Loss(t);
			}
			std::cout << "loss=" << loss << std::endl;
		}
		Validation();
	}
	virtual void Optimize_GD_Without_Ipopt() {
		std::cout << "~~~~~~~~using GD~~~~~~~~~" << std::endl;
		USE_RANDOM_SAMPLE = true;
		real loss = 100;
		real loss_pre = 100;
		std::cout << "tol=" << tol << std::endl;
		int iter = 0;
		
		//// SGD momentum only
		VectorX m ;
		m.resize(nt);
		m.fill(0);
		
		//// Adagrad
		VectorX v;
		v.resize(nt);
		v.fill(0);

		real beta_1 = 0.9;
		real beta_2 = 0.999;

		while (iter<20) {
			for (int j = 0; j < data_s; j++) {
				pdes[j].update_random_sample();
			}
			iter++;
			dLdt(t, Lt);
			//loss_pre = Loss(t);
			//std::cout << "m=:" << std::endl;
#pragma omp parallel for 
			for (int i = 0; i < nt; i++) {
				//std::cout << m[i] << "," << std::endl;
				
				///// SGD momentum
				/*m[i] = 0.9 * m[i] + lr * Lt[i];
				t[i] = t[i] - m[i]; */

				// SGD
				//t[i] = t[i] - lr * Lt[i];

				//// Adagrad
				/*v[i] += Lt[i] * Lt[i];
				t[i] = t[i] - lr * Lt[i] / (sqrt(v[i]+1e-8));*/

				////Adam
				m[i] = beta_1 * m[i] + (1 - beta_1) * Lt[i];
				v[i] = beta_2 * v[i] + (1 - beta_2) * Lt[i] * Lt[i];
				real m_hat = m[i] / (1 - pow(beta_1,iter));
				real v_hat = v[i] / (1 - pow(beta_2, iter));
				real alpha = lr * sqrt(1 - pow(beta_2, iter)) / (1 - pow(beta_1, iter));
				//real alpha = lr / (sqrt(iter));
				t[i] = t[i] - alpha* m_hat / (sqrt(v_hat) + 1e-8);

			}
			loss = 0;
#pragma omp parallel for 
			for (int i = 0; i < data_s; i++) loss += pdes[i].Loss_total(t); 
			real loss_rs = Loss(t);
			//if (iter % 500 == 0) { std::cout << "decay" << std::endl; lr = lr * 0.1; }
			//if (abs(loss - loss_pre) < 1e-1) { lr = lr * 1e-1; }
			//lr = lr * 0.999;
			//std::cout << "lr=" << lr << std::endl;
			std::cout <<"total loss="<<loss<< ", rs loss=" << loss_rs << std::endl;
		}


	}

	////Sum of the per-sample losses at parameters t over all data_s training samples.
	////Uses pdes[i].Loss_rs when USE_RANDOM_SAMPLE is set and pdes[i].Loss_total otherwise.
	real Loss(const VectorX& t)
	{
		real loss=(real)0;
		////reduction gives each thread a private partial sum; a plain += on the shared variable is a data race
#pragma omp parallel for reduction(+:loss)
		for(int i=0;i<data_s;i++){
			if(USE_RANDOM_SAMPLE)	loss+=pdes[i].Loss_rs(t);
			else loss += pdes[i].Loss_total(t);
		}
		return loss;
	}


	void dLdt(const VectorX& t,VectorX& Lt)
	{
		Lt.fill((real)0);
#pragma omp parallel for 
		for(int i=0;i<data_s;i++){
			VectorX Lti(nt);
			pdes[i].dLdt(t,Lti);
			Lt+=Lti;
		}
	}

	////Evaluate the current parameters t on test samples and write the results to disk.
	////Builds max(16,data_s) samples from testing_set with data indices data_s, data_s+1, ...,
	////prints each sample's total loss, stores target/prediction pairs in true_values/pred_values,
	////and calls Write_Output_Files for every sample.
	void Validation()
	{
		std::cout<<"------------------start validation------------------"<<std::endl;
		NeuralPDE<DIMENSION> pde;
		pde.use_random_sample=false;
		pde.Initialize();
		pde.nn=&nn;
		pde.kernel=&kernel;
		pde.example=&testing_set;

		const int test_s=std::max(16,data_s);

		////start from empty result arrays: Write_Output_Files indexes them by sample number,
		////so leftovers from an earlier call would be written instead of this run's results
		true_values.clear();
		pred_values.clear();
		true_values.reserve(test_s);
		pred_values.reserve(test_s);

		std::cout << "~~~~~" << std::endl;
		for(int i=0;i<test_s;i++){
			VectorX b(pde.n),tar(pde.n);
			testing_set.Build_Data_Sample(b,tar,i+data_s);	////offset by data_s so test indices differ from training indices
			std::cout<<"test loss"<<std::endl;
			pde.b=b;
			pde.tar=tar;
			real loss=pde.Loss_total(t);
			std::cout<<"Testing loss: "<<loss<<std::endl;

			true_values.push_back(tar);
			pred_values.push_back(pde.x[pde.L-1]);}		////last entry of pde.x after the loss evaluation

		for(int i=0;i<test_s;i++){
			Write_Output_Files(i);}
	}

	void Write_Output_Files(const int frame)
	{
		if(!File::Directory_Exists(output_dir.c_str()))
			File::Create_Directory(output_dir);

		std::string frame_dir=output_dir+"/"+std::to_string(frame);
		if(!File::Directory_Exists(frame_dir.c_str()))File::Create_Directory(frame_dir);
		
		{std::string file_name=output_dir+"/0/last_frame.txt";
		File::Write_Text_To_File(file_name,std::to_string(frame));}

		////write true values
		{Particles<2> particles;
		SegmentMesh<2> curves;
		int n=(int)true_values[0].size();
		particles.Resize(n);
		curves.Vertices().resize(n);
		for(int i=0;i<n;i++){
			particles.X(i)=Vector2((real)i*DX,true_values[frame][i]);
			curves.Vertices()[i]=Vector2((real)i*DX,true_values[frame][i]);
			if(i>0)curves.elements.push_back(Vector2i(i-1,i));}
		std::string file_name=frame_dir+"/p1";
		particles.Write_To_File_3d(file_name);
		file_name=frame_dir+"/s1";
		curves.Write_To_File_3d(file_name);
		file_name=output_dir+"/0/true_"+std::to_string(frame)+".txt";
		Array<real> values;for(int i=0;i<n;i++)values.push_back(true_values[frame][i]);
		File::Write_Text_Array_To_File(file_name,values,n);
		}

		////write predict values
		{Particles<2> particles;
		SegmentMesh<2> curves;
		int n=(int)true_values[0].size();
		particles.Resize(n);
		curves.Vertices().resize(n);
		for(int i=0;i<n;i++){
			particles.X(i)=Vector2((real)i*DX,pred_values[frame][i]);
			curves.Vertices()[i]=Vector2((real)i*DX,pred_values[frame][i]);
			if(i>0)curves.elements.push_back(Vector2i(i-1,i));}
		std::string file_name=frame_dir+"/p2";
		particles.Write_To_File_3d(file_name);
		file_name=frame_dir+"/s2";
		curves.Write_To_File_3d(file_name);
		file_name=output_dir+"/0/pred_"+std::to_string(frame)+".txt";
		Array<real> values;for(int i=0;i<n;i++)values.push_back(pred_values[frame][i]);
		File::Write_Text_Array_To_File(file_name,values,n);}

		std::cout<<"write to "<<frame_dir<<std::endl;
	}
};

#endif
