/**
File:		Electricity_Traffic.cpp

Author:		
Email:		
Site:       

Copyright (c) 2022 . All rights reserved.
*/

#include <NeCore.h>
#include <NeMachineLearning.h>

#include <exception>
#include <filesystem>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using namespace NeuralEngine;
using namespace NeuralEngine::MachineLearning;
using namespace NeuralEngine::MachineLearning::GPModels;
using namespace af;

namespace fs = std::filesystem;

////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary>
/// 	Rolling-window (expanding training set) evaluation of AEP::SDGPSSM on a multivariate
/// 	time series, writing per-sample forecasts, the ground truth and the last model to disk.
/// </summary>
///
/// <remarks>
/// 	The last nrWindows * windowSize rows of Y form the test horizon. For window i a fresh model
/// 	is trained on every row preceding that window and then forecasts windowSize steps ahead with
/// 	numSamples Monte-Carlo samples. Each sample's forecasts of all windows are accumulated with
/// 	CommonUtil::Join (presumably row-wise concatenation - TODO confirm).
/// 	Training data, forecasts and trueY are all z-scored with the column mean / stdev of the full
/// 	series; forecasts are NOT de-normalised. Results go to
/// 	path\resources\AAAI23\modelName\mode\Alpha_&lt;alpha&gt;\ (missing directories are created).
/// </remarks>
///
/// <param name="Y">		 	Data matrix, rows = time steps, columns = series. NOTE: when
/// 							modelName == "Traffic" it is truncated IN PLACE to its first 440 columns. </param>
/// <param name="windowSize">	Forecast horizon (rows) of each window; must be &gt; 0. </param>
/// <param name="nrWindows"> 	Number of consecutive test windows; must be &gt; 0 and
/// 							nrWindows * windowSize must be smaller than the number of rows of Y. </param>
/// <param name="mode">	 	Label used only as a sub-directory name of the output path. </param>
/// <param name="modelName"> 	Data set name ("Traffic" selects its specific settings); also a sub-directory. </param>
/// <param name="path">	 	Root directory under which resources\AAAI23\... is written. </param>
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Scalar>
void RollingCV(af::array& Y, int windowSize, int nrWindows, std::string mode, std::string modelName, std::string path)
{
	[[maybe_unused]] const dtype m_dType = CommonUtil<Scalar>::CheckDType();

	// Only the first 440 series of the traffic data set are used.
	if (modelName == "Traffic") Y = Y(af::span, seq(0, 440 - 1));

	const int T = static_cast<int>(Y.dims(0));	// length of the dataset

	// The first training window (rows [0, T - nrWindows * windowSize)) must be non-empty.
	if (windowSize <= 0 || nrWindows <= 0 || static_cast<long long>(nrWindows) * windowSize >= T)
	{
		std::cerr << "RollingCV: invalid rolling-window configuration for " << modelName
			<< " (T = " << T << ", windowSize = " << windowSize << ", nrWindows = " << nrWindows << ").\n";
		return;
	}

	int iq = 10, batchSize = 24, numIter = 350;
	Scalar priorMean = 0.0;
	Scalar priorVariance = (modelName == "Traffic") ? Scalar(1.0) : Scalar(1000.0);

	const std::vector<Scalar> valpha{ 0.1, 0.25, 0.5, 0.75, 0.9 };

	int hk = 200, hD = 100, numSamples = 10;
	HiddenLayerDescription description(hk, hD);

	// Column-wise statistics of the full series, shared by training data and ground truth.
	const af::array afMean = af::mean(Y);
	const af::array afStd = af::stdev(Y);

	// Z-score a block of rows. Constant columns (stdev 0) give 0/0 = NaN, which is mapped to 0.
	auto normalise = [&afMean, &afStd](const af::array& block)
	{
		af::array z = (block - af::tile(afMean, block.dims(0), 1)) / af::tile(afStd, block.dims(0), 1);
		z(af::isNaN(z)) = 0.0;
		return z;
	};

	// Ground truth of the whole test horizon, normalised exactly once.
	af::array trueY = normalise(Y(af::seq(T - (nrWindows * windowSize), af::end), af::span));

	af::array my, vy;
	//for (auto j = 0; j < valpha.size(); j++)
	{
		Scalar dAlpha = valpha[2];
		std::vector<af::array> forecastMY, forecastVY;
		forecastMY.reserve(numSamples);
		forecastVY.reserve(numSamples);
		Timer timer;

		// Owns the model of the current window; only the last one survives the loop.
		std::unique_ptr<AEP::SDGPSSM<Scalar>> model;
		for (int i = 0; i < nrWindows; i++)
		{
			// Expanding window: train on every row before the i-th test window.
			const int trnEnd = T - (nrWindows - i) * windowSize;
			af::array YTrain = normalise(Y(af::seq(0, trnEnd - 1), af::span));

			model.reset();	// release the previous window's model before allocating the next one
			model = std::make_unique<AEP::SDGPSSM<Scalar>>(YTrain, iq, description, dAlpha, priorMean, priorVariance, af::array(), PropagationMode::MonteCarlo);

			model->Optimise(OptimizerType::ADAM, 0.0, false, numIter, batchSize, LineSearchType::MoreThuente);

			model->PredictForward(windowSize, my, vy, numSamples);

			// Append this window's forecasts for every Monte-Carlo sample.
			for (int ns = 0; ns < numSamples; ns++)
			{
				if (i == 0)
				{
					forecastMY.push_back(my(af::span, af::span, ns));
					forecastVY.push_back(vy(af::span, af::span, ns));
				}
				else
				{
					forecastMY[ns] = CommonUtil<Scalar>::Join(forecastMY[ns], my(af::span, af::span, ns));
					forecastVY[ns] = CommonUtil<Scalar>::Join(forecastVY[ns], vy(af::span, af::span, ns));
				}
			}
		}

		const std::string alphaDir = path + "\\resources\\AAAI23\\" + modelName + "\\" + mode + "\\Alpha_" + std::to_string(dAlpha);
		// create_directories also creates missing parents instead of throwing.
		fs::create_directories(alphaDir);

		const std::string savePath = alphaDir + "\\";
		for (int ns = 0; ns < numSamples; ns++)
		{
			CommonUtil<Scalar>::WriteTXT(forecastMY[ns], savePath + "forecastMY_Sample_" + std::to_string(ns) + ".txt");
			CommonUtil<Scalar>::WriteTXT(forecastVY[ns], savePath + "forecastVY_Sample_" + std::to_string(ns) + ".txt");
		}

		CommonUtil<Scalar>::WriteTXT(trueY, savePath + "trueY.txt");
		// Only the last window's model is saved; the file name records the total elapsed training time.
		SaveModel<AEP::SDGPSSM<Scalar>>(savePath + "SDGPSSM_time_" + std::to_string(timer.GetSeconds()) + "sec.dat", model.get());
	}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
 /// <summary>	
 /// 	AEP::DeepGPSSM to learn electricity dataset. 
 /// </summary>
 ///
 /// <remarks>	Hmetal T, 01/06/2018. </remarks>
 ///
 /// <param name="path">	Full pathname of the file. </param>
 ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Scalar>
void ElectricityTraffic_AEP_DGPSSM_Stocatic(std::string path)
{
	af::dtype m_dType = CommonUtil<Scalar>::CheckDType();

	std::string mode = "GPU";
	std::string modelName = "Electricity";

	int windowSize = 24, nrWindows = 7;

	af::array Y, trueY;
	af::array forecastMY, forecastVY;

	std::cout << "Loading Electricity Dataset...\n";
	Y = CommonUtil<Scalar>::ReadTXT(path + "\\data\\electricity.txt");
	std::cout << "Done.\n\n";
	RollingCV<Scalar>(Y, windowSize, nrWindows, mode, modelName, path);

	modelName = "Traffic";
	std::cout << "Loading Traffic Dataset...\n";
	Y = CommonUtil<Scalar>::ReadTXT(path + "\\data\\traffic.txt");
	std::cout << "Done.\n\n";
	RollingCV<Scalar>(Y, windowSize, nrWindows, mode, modelName, path);

	system("pause");
}

template<typename Scalar>
void ComputeMetrics(std::string path)
{
	af::dtype m_dType = CommonUtil<Scalar>::CheckDType();

	std::string mode = "GPU";
	std::string modelName = "Electricity";
	int numSamples = 10;
	Scalar dAlpha = 0.5;

	af::array NRMSE = af::constant(0.0, numSamples, m_dType);
	af::array ND = af::constant(0.0, numSamples, m_dType);
	af::array MASE = af::constant(0.0, numSamples, m_dType);
	af::array trueY;

	std::string loadPath = path + "\\resources\\AAAI23\\" + modelName + "\\" + mode + "\\Alpha_" + std::to_string(dAlpha) + "\\";
	trueY = CommonUtil<Scalar>::ReadTXT(loadPath + "trueY.txt");
	for (auto ns = 0; ns < numSamples; ns++)
	{
		af::array forecastMY = CommonUtil<Scalar>::ReadTXT(loadPath + "forecastMY_Sample_" + std::to_string(ns) + ".txt");
		//af::array forecastVY = CommonUtil<Scalar>::ReadTXT(loadPath + "forecastVY_Sample_" + std::to_string(ns) + ".txt");

		NRMSE(ns) = Metrics<Scalar>::NRMSE(trueY, forecastMY);
		ND(ns) = Metrics<Scalar>::ND(trueY, forecastMY);
		MASE(ns) = Metrics<Scalar>::MASE(trueY, forecastMY);
	}

	std::cout << modelName + " NRMSE: " << af::mean<Scalar>(NRMSE) << std::endl;
	std::cout << modelName + " ND: " << af::mean<Scalar>(ND) << std::endl;
	std::cout << modelName + " MASE: " << af::mean<Scalar>(MASE) << std::endl << std::endl;

	modelName = "Traffic";
	loadPath = path + "\\resources\\AAAI23\\" + modelName + "\\" + mode + "\\Alpha_" + std::to_string(dAlpha) + "\\";

	trueY = CommonUtil<Scalar>::ReadTXT(loadPath + "trueY.txt");
	for (auto ns = 0; ns < numSamples; ns++)
	{
		af::array forecastMY = CommonUtil<Scalar>::ReadTXT(loadPath + "forecastMY_Sample_" + std::to_string(ns) + ".txt");
		//af::array forecastVY = CommonUtil<Scalar>::ReadTXT(loadPath + "forecastVY_Sample_" + std::to_string(ns) + ".txt");

		NRMSE(ns) = Metrics<Scalar>::NRMSE(trueY, forecastMY);
		ND(ns) = Metrics<Scalar>::ND(trueY, forecastMY);
		MASE(ns) = Metrics<Scalar>::MASE(trueY, forecastMY);
	}

	std::cout << modelName + " NRMSE: " << af::mean<Scalar>(NRMSE) << std::endl;
	std::cout << modelName + " ND: " << af::mean<Scalar>(ND) << std::endl;
	std::cout << modelName + " MASE: " << af::mean<Scalar>(MASE) << std::endl << std::endl;

	system("pause");
}

 int main(int, char const*[])
 {
 #if defined(_DEBUG)
	LogReporter reporter(
		"LogReport.txt",
		Listener::LISTEN_FOR_ALL,
		Listener::LISTEN_FOR_ALL,
		Listener::LISTEN_FOR_ERROR,
		Listener::LISTEN_FOR_NOTHING);
 #endif

	Environment env;
	std::string nepath = env.GetVariable("NE_PATH");
	if (nepath == "")
	{
		LogError("You must create the environment variable NE_PATH.");
		return 0;
	}

	HWND consoleWindow = GetConsoleWindow();

	SetWindowPos(consoleWindow, 0, 10, 10, 0, 0, SWP_NOSIZE | SWP_NOZORDER);

	ElectricityTraffic_AEP_DGPSSM_Stocatic<double>(nepath);
	ComputeMetrics<double>(nepath);

	
	return 0;
}
