/**
File:		MachineLearning/Graph/Node/FgDiscreteVariableNode.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgDiscreteVariableNode.h>
#include <MachineLearning/FgDiscreteMessage.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		/**
		 * Creates a discrete variable node.
		 *
		 * @param name       Node name, passed unchanged to IVariableNode.
		 * @param numStates  Number of discrete states, passed unchanged to IVariableNode.
		 */
		DiscreteVariableNode::DiscreteVariableNode(std::string name, int numStates)
			: IVariableNode(name, numStates)
		{
			// All state lives in the IVariableNode base; nothing extra to set up here.
		}

		// No resources are owned directly by this class, so the compiler-generated
		// destructor body is sufficient.
		DiscreteVariableNode::~DiscreteVariableNode() = default;

		/**
		 * Reports whether this node accepts messages of the given type.
		 *
		 * @param type  Message type to check.
		 * @return true only for MsgType::eDiscreteMessage.
		 */
		bool DiscreteVariableNode::IsSupported(MsgType type)
		{
			const bool isDiscrete = (MsgType::eDiscreteMessage == type);
			return isDiscrete;
		}

		/**
		 * Reports whether this node is a leaf node.
		 *
		 * @return Always false for a DiscreteVariableNode.
		 */
		bool DiscreteVariableNode::IsLeafNode()
		{
			// Return a bool literal rather than the integer 0 (same value, explicit intent).
			return false;
		}

		/**
		 * Computes the normalized marginal of this variable from its incoming messages.
		 *
		 * The log-domain values of all messages (DiscreteMessage::GetLogValue) are summed.
		 * The maximum is subtracted before exponentiating, for numerical stability. The
		 * result is then divided by its sum so the entries add up to 1.
		 *
		 * @param neededMessages  Incoming messages keyed by sender name. Every entry is
		 *                        treated as a DiscreteMessage.
		 * @return Reference to the normalized marginal. The declared return type is a
		 *         non-const reference, so the value is kept in per-thread storage owned
		 *         by this function. It stays valid until the next call to Marginal on
		 *         the same thread; copy it if it must be kept longer.
		 */
		af::array& DiscreteVariableNode::Marginal(MsgBox& neededMessages)
		{
			// NOTE(review): the shape is hard-coded to 1x2 even though the constructor
			// takes numStates - confirm, or size it from the node's state count.
			af::array logProd = af::constant(0.0, 1, 2);

			for (auto& entry : neededMessages)
			{
				IMessage& msg = entry.second;
				// NOTE(review): assumes every stored IMessage is really a DiscreteMessage.
				logProd += static_cast<DiscreteMessage&>(msg).GetLogValue();
			}

			// Reduce to host scalars: array-vs-1x1-array arithmetic needs matching
			// dimensions in ArrayFire, whereas array-vs-scalar arithmetic is always valid.
			logProd = af::exp(logProd - af::max<double>(logProd));

			// The header fixes the return type as af::array&, so the result cannot be a
			// temporary (that would dangle). Keep it in per-thread storage instead.
			// NOTE(review): returning af::array by value (header + base class change)
			// would be the cleaner long-term fix.
			thread_local af::array marginal;
			marginal = logProd / af::sum<double>(logProd);
			return marginal;
		}

		/**
		 * Builds the outgoing message from this variable to the node named toNodeName.
		 *
		 * The log-domain values of all incoming messages are summed, excluding the one
		 * whose sender key equals toNodeName.
		 *
		 * @param toNodeName      Name of the recipient node; its own message is skipped.
		 * @param neededMessages  Incoming messages keyed by sender name. Every entry is
		 *                        treated as a DiscreteMessage.
		 * @return The summed log values wrapped in a DiscreteMessage and returned as
		 *         IMessage.
		 */
		IMessage DiscreteVariableNode::ComputeMessage(std::string toNodeName, MsgBox &neededMessages)
		{
			DiscreteMessage outgoing;
			// NOTE(review): 1x2 shape is hard-coded regardless of numStates - confirm.
			af::array logSum = af::constant(0.0, 1, 2);

			for (auto& entry : neededMessages)
			{
				const std::string& sender = entry.first;
				if (sender != toNodeName)
				{
					IMessage& incoming = entry.second;
					logSum += static_cast<DiscreteMessage&>(incoming).GetLogValue();
				}
			}

			outgoing.SetLogValue(logSum);
			// NOTE(review): returning IMessage by value slices the DiscreteMessage; this is
			// only safe if DiscreteMessage adds no data members beyond IMessage - verify.
			return outgoing;
		}
	}
}