//#####################################################################
// Neural PDE optimization driver
// Copyright (c) (2018-), Bo Zhu, boolzhu@gmail.com
// This file is part of SLAX, whose distribution is governed by the LICENSE file.
//#####################################################################
#ifndef __NeuralPDEDriver_h__
#define __NeuralPDEDriver_h__
#include "Driver.h"
#include "NeuralPDE.h"
#include "NeuralNetwork.h"

/// Driver that owns a NeuralPDEOptimizer and forwards the Driver-style
/// Initialize()/Run() calls to it.
/// NOTE(review): if Driver declares Initialize()/Run() virtual, consider
/// marking the two methods below `override`.
class NeuralPDEDriver : public Driver
{
	using Base = Driver;	// private shorthand for the parent driver type

public:
	/// The optimizer this driver wraps; prepared by Initialize() and executed by Run().
	NeuralPDEOptimizer pde_opt;

	NeuralPDEDriver() {}

	/// Set up the wrapped optimizer by calling pde_opt.Initialize().
	void Initialize()
	{
		pde_opt.Initialize();
	}

	/// Execute the optimization by calling pde_opt.Optimize().
	/// NOTE(review): earlier commented-out code offered
	/// pde_opt.Optimize_GD_Without_Ipopt() as an alternative entry point;
	/// confirm it still exists on NeuralPDEOptimizer before switching to it.
	void Run()
	{
		pde_opt.Optimize();
	}
};

#endif