# factor_graph.py
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import time
import torch
import itertools
from tqdm import tqdm
import numpy as np

from itertools import product

from ._utils import _cast_as_tensor
from ._utils import _update_parameter
from ._utils import _check_parameter
from ._utils import _reshape_weights

from .distributions._distribution import Distribution
from .distributions._distribution import ConditionalDistribution

from .distributions import NeuralDistribution
from .distributions import EgoNeuralDistribution
from .distributions import Categorical
from .distributions import JointCategorical


class FactorGraph(Distribution):
	"""A factor graph object.

	A factor graph represents a probability distribution as a bipartite graph
	where marginal distributions of each dimension in the distribution are on
	one side of the graph and factors are on the other side. The distributions
	on the factor side encode probability estimates from the model, whereas the
	distributions on the marginal side encode probability estimates from the
	data. 

	Inference is done on the factor graph using the loopy belief propagation
	algorithm. This is an iterative algorithm where "messages" are passed
	along each edge between the marginals and the factors until the estimates
	for the marginals converges. In brief: each message represents what the
	generating node thinks its marginal distribution is with respect to the
	child node. Calculating each message involves marginalizing the parent
	node with respect to *every other* node. When the parent node is already
	a univariate distribution -- either because it is a marginal node or
	a univariate factor node -- no marginalization is needed and it sends
	itself as the message. Basically, a joint probability table will receive
	messages from all the marginal nodes that comprise its dimensions and,
	to each of those marginal nodes, it will send a message back saying what
	it (the joint probability table) thinks its marginal distribution is with
	respect to the messages from the OTHER marginals. More concretely, if a
	joint probability table has two dimensions with marginal node parents
	A and B, it will send a message to A that is itself after marginalizing out
	B, and will send a message to B that is itself after marginalizing out A. 

	..note:: It is worth noting that this algorithm is exact when the structure
	is a tree. If there exist any loops in the factors, i.e., you can draw a
	circle beginning with a factor and then hopping between marginals and
	factors and make it back to the factor without crossing any edges twice,
	the probabilities returned are approximate.


	Parameters
	----------
	factors: tuple or list or None
		A set of distribution objects. These do not need to be initialized,
		i.e. can be "Categorical()". Currently, they must be either Categorical
		or JointCategorical distributions. Default is None.

	marginals: tuple or list or None
		A set of distribution objects. These must be initialized and be
		Categorical distributions.

	edges: list or tuple or None
		A set of edges. Critically, the items in this list must be the
		distribution objects themselves, and the order that edges must match
		the order distributions in a multivariate distribution. Specifically,
		if you have a multivariate distribution, the first edge that includes
		it must correspond to the first dimension, the second edge must
		correspond to the second dimension, etc, and the total number of
		edges cannot exceed the number of dimensions. Default is None.

	max_iter: int, optional
		The number of iterations to do in the inference step as distributions
		are converging. Default is 10.

	tol: float, optional
		The threshold at which to stop during fitting when the improvement
		goes under. Default is 1e-6.

	inertia: float, [0, 1], optional
		Indicates the proportion of the update to apply to the parameters
		during training. When the inertia is 0.0, the update is applied in
		its entirety and the previous parameters are ignored. When the
		inertia is 1.0, the update is entirely ignored and the previous
		parameters are kept, equivalently to if the parameters were frozen.

	frozen: bool, optional
		Whether all the parameters associated with this distribution are frozen.
		If you want to freeze individual pameters, or individual values in those
		parameters, you must modify the `frozen` attribute of the tensor or
		parameter directly. Default is False.

	check_data: bool, optional
		Whether to check properties of the data and potentially recast it to
		torch.tensors. This does not prevent checking of parameters but can
		slightly speed up computation when you know that your inputs are valid.
		Setting this to False is also necessary for compiling.

	verbose: bool, optional
		Whether to print the improvement and timings during training.
	"""

	def __init__(self, factors=None, marginals=None, edges=None, max_iter=20, 
		tol=1e-6, inertia=0.0, frozen=False, check_data=True, verbose=False):
		super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
		self.name = "FactorGraph"

		self.factors = torch.nn.ModuleList([])
		self.marginals = torch.nn.ModuleList([])

		self._factor_idxs = {}
		self._marginal_idxs = {}

		self._factor_edges = []
		self._marginal_edges = []

		self.max_iter = _check_parameter(max_iter, "max_iter", min_value=1, 
			dtypes=[int, torch.int16, torch.int32, torch.int64])
		self.tol = _check_parameter(tol, "tol", min_value=0)
		self.verbose = verbose

		self.d = 0
		self._initialized = factors is not None and factors[0]._initialized

		if factors is not None:
			_check_parameter(factors, "factors", dtypes=(list, tuple))

			for factor in factors:
				self.add_factor(factor)

		if marginals is not None:
			_check_parameter(marginals, "marginals", dtypes=(list, tuple))

			for marginal in marginals:
				self.add_marginal(marginal)

		if edges is not None:
			_check_parameter(edges, "edges", dtypes=(list, tuple))

			for marginal, factor in edges:
				self.add_edge(marginal, factor)

		self._initialized = not factors

		self.trained = False

	def _initialize(self, d):
		self._initialized = True
		super()._initialize(d)

	def _reset_cache(self):
		return

	def add_factor(self, distribution):
		"""Adds a distribution to the set of factors.

		
		Parameters
		----------
		distribution: pomegranate.distributions.Distribution
			A distribution object to include as a node.
		"""

		if not isinstance(distribution, (Categorical, JointCategorical, NeuralDistribution, EgoNeuralDistribution)):
			raise ValueError("Must be a Categorical or a JointCategorical or NeuralDistribution"
				" distribution.")

		self.factors.append(distribution)
		self._factor_edges.append([])
		self._factor_idxs[distribution] = len(self.factors) - 1

		self._initialized = distribution._initialized

	def add_marginal(self, distribution):
		"""Adds a distribution to the set of marginals.

		This adds a distribution to the marginal side of the bipartate graph.
		This distribution must be univariate. 

		Parameters
		----------
		distribution: pomegranate.distributions.Distribution
			A distribution object to include as a node.
		"""

		if not isinstance(distribution, Categorical):
			raise ValueError("Must be a Categorical distribution.")
		
		self.marginals.append(distribution)
		self._marginal_edges.append([])
		self._marginal_idxs[distribution] = len(self.marginals) - 1

		self.d += 1

	def add_edge(self, marginal, factor):
		"""Adds an undirected edge to the set of edges.

		Because a factor graph is a bipartite graph, one of the edges must be
		a marginal distribution and the other edge must be a factor 
		distribution.

		Parameters
		----------
		marginal: pomegranate.distributions.Distribution
			The marginal distribution to include in the edge.

		factor: pomegranate.distributions.Distribution
			The factor distribution to include in the edge.
		"""

		if marginal not in self._marginal_idxs:
			raise ValueError("Marginal distribution does not exist in graph.")
		
		if factor not in self._factor_idxs:
			raise ValueError("Factor distribution does not exist in graph.")

		m_idx = self._marginal_idxs[marginal]
		f_idx = self._factor_idxs[factor]

		if isinstance(factor, NeuralDistribution) or isinstance(factor, EgoNeuralDistribution):
			margina_index_in_factor = len(self._factor_edges[f_idx])
			number_of_categories = factor.categories[margina_index_in_factor]

			if number_of_categories != marginal.n_keys:
				raise ValueError("The number of categories in the factor and the marginal do not match.")

		self._factor_edges[f_idx].append(m_idx)
		self._marginal_edges[m_idx].append(f_idx)

	def log_probability(self, X):
		"""Calculate the log probability of each example.

		This method calculates the log probability of each example given the
		parameters of the distribution. The examples must be given in a 2D
		format.


		Parameters
		----------
		X: list, tuple, numpy.ndarray, torch.Tensor, shape=(-1, self.d)
			A set of examples to evaluate.

		Returns
		-------
		logp: torch.Tensor, shape=(-1,)
			The log probability of each example.
		"""

		X = _check_parameter(_cast_as_tensor(X), "X", ndim=2, 
			check_parameter=self.check_data)
		
		logps = torch.zeros(X.shape[0], device=X.device, dtype=torch.float32)
		for idxs, factor in zip(self._factor_edges, self.factors):
			logps += factor.log_probability(X[:, idxs])

		for i, marginal in enumerate(self.marginals):
			logps += marginal.log_probability(X[:, i:i+1])

		return logps

	def predict(self, X):
		"""Infers the maximum likelihood value for each missing value.

		This method infers a probability distribution for each of the missing
		values in the data. First, the sum-product algorithm is run to infer
		a probability distribution for each variable. Then, the maximum
		likelihood value is returned from that distribution.

		The input to this method must be a torch.masked.MaskedTensor where the
		mask specifies which variables are observed (mask = True) and which ones
		are not observed (mask = False) for each of the values. When setting
		mask = False, it does not matter what the corresponding value in the
		tensor is. Different sets of variables can be observed or missing in
		different examples. 

		Unlike the `predict_proba` and `predict_log_proba` methods, this
		method preserves the dimensions of the original data because it does
		not matter how many categories a variable can take when you're only
		returning the maximally likely one.


		Parameters
		----------
		X: torch.masked.MaskedTensor
			A masked tensor where the observed values are available and the
			unobserved values are missing, i.e., the mask is True for
			observed values and the mask is False for missing values. It does
			not matter what the underlying value in the tensor is for the 
			missing values.
		"""

		y = [t.argmax(dim=1) for t in self.predict_proba(X)]
		return torch.vstack(y).T.contiguous()

	def predict_proba(self, X, alg="max"):
		"""Predict the probability of each variable given some evidence.

		Given some evidence about the value that each variable takes, infer
		the value that each other variable takes. If no evidence is given,
		this returns the marginal value of each variable given the dependence
		structure in the network.

		Currently, only hard evidence is supported, where the evidence takes
		the form of a discrete value. The evidence is represented as a
		masked tensor where the masked out values are considered missing.


		Parameters
		----------
		X: torch.masked.MaskedTensor
			A masked tensor where the observed values are available and the
			unobserved values are missing, i.e., the mask is True for
			observed values and the mask is False for missing values. It does
			not matter what the underlying value in the tensor is for the 
			missing values.
		"""

		nm = len(self.marginals)
		nf = len(self.factors)

		if X.shape[1] != nm:
			raise ValueError("X.shape[1] must match the number of marginals.")

		# Pull out the current distributions from factors and marginals
		factors = []
		marginals = []
		prior_marginals = []
		current_marginals = []

		for i, m in enumerate(self.marginals):
			p = torch.clone(m.probs[0])
			p = p.repeat((X.shape[0],) + tuple(1 for _ in p.shape))

			# Use the evidence in the data to set marginal distributions
			for j in range(X.shape[0]):
				if X._masked_mask[j, i] == True:
					value = X._masked_data[j, i]

					p[j] = 0
					p[j, value] = 1.0

			marginals.append(p)
			prior_marginals.append(torch.clone(p))
			current_marginals.append(torch.clone(p))

		# this part of the code is copying the factors of the model based on the number of inputs that we have.
		# since in NN there will be only one input, I am ignoring this (kinda)
		for i, f in enumerate(self.factors):
			if isinstance(f, NeuralDistribution) or isinstance(f, EgoNeuralDistribution):
				mask = X._masked_mask[:, self._factor_edges[i]]
				data = X._masked_data[:, self._factor_edges[i]]
				# non_masked_data = data[mask]
				# factors.append(f.model(torch.stack([non_masked_data])).item())
				factors.append(f.model)
			elif not isinstance(f, Categorical):
				p = torch.clone(f.probs)
				# print(p)
				p = p.repeat((X.shape[0],) + tuple(1 for _ in p.shape))
				# p = p * 10
				factors.append(p)
				# print(p)
				# exit()
			else:
				p = torch.clone(f.probs[0])
				p = p.repeat((X.shape[0],) + tuple(1 for _ in p.shape))
				# p = p * 10
				factors.append(p)


		# Set the original in and out messages along the edges
		in_messages, out_messages = [], []

		for i, m in enumerate(marginals):
			k = len(self._marginal_edges[i])
			in_messages.append([])

			for j in range(k):
				in_messages[-1].append(m)

		# length of factors is the same as length of self.factors
		for i in range(len(self.factors)):
			k = len(self._factor_edges[i])
			out_messages.append([])
			
			for j in range(k):
				marginal_idx = self._factor_edges[i][j]
				d_j = marginals[marginal_idx]
				out_messages[-1].append(d_j)

		#  in_message is f->x and out_message is x->f

		# out_messages each entry is the messages sending out from a factor
		# each out message is ordered by the variables connected to the factor
		# out messages is ordered by the factors in the model
		# Begin iterations
		iteration = 0
		while iteration < self.max_iter:
			# Update the messages going into the nodes.
			for i in range(len(self.factors)):
				f = self.factors[i]
				if not (isinstance(f, NeuralDistribution) or isinstance(f, EgoNeuralDistribution)):
					f = factors[i]
					ni_edges = len(self._factor_edges[i])
					for k in range(ni_edges):
						# lets see what message does here :))
						message = torch.clone(f)
						shape = torch.ones(len(message.shape), dtype=torch.int32)
						shape[0] = X.shape[0]

						for l in range(ni_edges):
							if k == l:
								continue

							shape[l+1] = message.shape[l+1]	
							message *= out_messages[i][l].reshape(*shape)		# message is not actually message, it is the internalized calculation on the factor 

							if alg == "max":
								message = torch.max(message, dim=l+1, keepdims=True)[0]
							elif alg == "sum":
								message = torch.sum(message, dim=l+1, keepdims=True)

							shape[l+1] = 1

						else:
							message = message.squeeze()
							if len(message.shape) == 1:
								message = message.unsqueeze(0)

						j = self._factor_edges[i][k]
						for ik, parent in enumerate(self._marginal_edges[j]):
							if parent == i:
								dims = tuple(range(1, len(message.shape)))
								# print(message)
								in_messages[j][ik] = message / message.sum(		#normalizing
									dim=dims, keepdims=True)

								break

	
				else:
					ni_edges = len(self._factor_edges[i])
					# for every edge connected to the factor
					for k in range(ni_edges):
						 # Iterate over each possible value of the target variable
						message = {}
						for x_i_val in range(f.categories[k]):
							sum_over_values = 0
							max_over_values = 0  # this is for max-product algorithms
							# Iterate over all combinations of values for other variables in the factor
							other_vars = [var for var in range(ni_edges) if var != k]
							relevant_values = {
								var: [val for val, prob in enumerate(out_messages[i][var][0]) if prob > 0]
								for var in other_vars
							}
							# Generate only the combinations of non-zero incoming messages
							relevant_combinations = product(*[relevant_values[var] for var in other_vars])
							combs = [ i for i in relevant_combinations]
							iteration_helper = combs
							if self.verbose:
								iteration_helper = tqdm(combs)
							for other_vals in iteration_helper:
								# Construct a full variable assignment for the factor
								var_values = {k: x_i_val}
								var_values.update({var: val for var, val in zip(other_vars, other_vals)})

								# Evaluate the joint probability and add to the sum
								
								joint_prob = f.probability([var_values[var] for var in range(ni_edges)])

								# Multiply by the incoming messages from each of the other variables
								product_of_messages = joint_prob
								for var, val in var_values.items():
									if var != k:
										product_of_messages *= out_messages[i][var][0][var_values[var]]
								sum_over_values += product_of_messages
								max_over_values = max(max_over_values, product_of_messages)

							# Store the sum in the message for the current value of the target variable
							if alg == "sum":
								message[x_i_val] = sum_over_values
							elif alg == "max":
								message[x_i_val] = max_over_values
						m2 = torch.tensor([[message[i] for i in message.keys()]])

						j = self._factor_edges[i][k]

						for ik, parent in enumerate(self._marginal_edges[j]):
							if parent == i:
								dims = tuple(range(1, len(m2.shape)))

								in_messages[j][ik] = m2 / m2.sum(		# normalizing
									dim=dims, keepdims=True)
								break


			# Calculate the current estimates of the marginals
			loss = 0
			for i, m in enumerate(marginals):
				current_marginals[i] = torch.clone(m)

				for k in range(len(self._marginal_edges[i])):
					# print(current_marginals[i], in_messages[i][k])
					current_marginals[i] *= in_messages[i][k]


				dims = tuple(range(1, len(current_marginals[i].shape)))
				current_marginals[i] /= current_marginals[i].sum(dim=dims, 
					keepdims=True)

				loss += torch.nn.KLDivLoss(reduction="batchmean")(torch.log(
					current_marginals[i] + 1e-8), prior_marginals[i])
				

			if self.verbose:
				print(iteration, loss.item())

			if loss < self.tol:
				# print("CONVERGED")
				# if self.verbose:
				# 	print("c", out_messages)
				# 	print("x", in_messages)
				# 	print("breaking at ", iteration)
				break


			# Update the messages leaving based on the new marginals
			for i, m in enumerate(marginals):
				ni_edges = len(self._marginal_edges[i])

				for k in range(ni_edges):
					message = torch.clone(m)

					for l in range(ni_edges):
						if k == l:
							continue

						message *= in_messages[i][l]

					j = self._marginal_edges[i][k]
					for ik, parent in enumerate(self._factor_edges[j]):
						if parent == i:
							dims = tuple(range(1, len(message.shape)))
							out_messages[j][ik] = message / message.sum(
								dim=dims, keepdims=True)
							break

			prior_marginals = [torch.clone(d) for d in current_marginals]
			iteration += 1

			# print(current_marginals[:6])


		return current_marginals

	def predict_log_proba(self, X):
		"""Infers the probability of each category given the model and data.

		This method is a wrapper around the `predict_proba` method and simply
		takes the log of each returned tensor.

		This method infers a log probability distribution for each of the 
		missing  values in the data. It uses the factor graph representation of 
		the Bayesian network to run the sum-product/loopy belief propagation
		algorithm.

		The input to this method must be a torch.masked.MaskedTensor where the
		mask specifies which variables are observed (mask = True) and which ones
		are not observed (mask = False) for each of the values. When setting
		mask = False, it does not matter what the corresponding value in the
		tensor is. Different sets of variables can be observed or missing in
		different examples. 

		An important note is that, because each variable can have a different
		number of categories in the categorical setting, the return is a list
		of tensors where each element in that list is the marginal probability
		distribution for that variable. More concretely: the first element will
		be the distribution of values for the first variable across all
		examples. When the first variable has been provided as evidence, the
		distribution will be clamped to the value provided as evidence.

		..warning:: This inference is exact given a Bayesian network that has
		a tree-like structure, but is only approximate for other cases. When
		the network is acyclic, this procedure will converge, but if the graph
		contains cycles then there is no guarantee on convergence.


		Parameters
		----------
		X: torch.masked.MaskedTensor, shape=(-1, d)
			The data to predict values for. The mask should correspond to
			whether the variable is observed in the example. 
		

		Returns
		-------
		y: list of tensors, shape=(d,)
			A list of tensors where each tensor contains the distribution of
			values for that dimension.
		"""

		return [torch.log(t) for t in self.predict_proba(X)]

	def fit(self, X, X_valid=None, sample_weight=None, from_file=None):
		"""Fit the factors of the model to optionally weighted examples.

		This method will fit the provided factor distributions to the given
		data and their optional weights. It will not update the marginal
		distributions, as that information is already encoded in the joint
		probabilities.

		..note:: A structure must already be provided. Currently, structure
		learning of factor graphs is not supported.


		Parameters
		----------
		X: list, tuple, numpy.ndarray, torch.Tensor, shape=(-1, self.d)
			A set of examples to evaluate. 

		sample_weight: list, tuple, numpy.ndarray, torch.Tensor, optional
			A set of weights for the examples. This can be either of shape
			(-1, self.d) or a vector of shape (-1,). Default is ones.


		Returns
		-------
		self
		"""

		self.summarize(X, X_valid=X_valid, sample_weight=sample_weight, from_file=None)
		self.from_summaries()
		self.trained = True
		return self

	def summarize(self, X, X_valid=None, sample_weight=None, from_file=None):
		"""Extract the sufficient statistics from a batch of data.

		This method calculates the sufficient statistics from optionally
		weighted data and adds them to the stored cache for each distribution
		in the network. Sample weights can either be provided as one
		value per example or as a 2D matrix of weights for each feature in
		each example.


		Parameters
		----------
		X: list, tuple, numpy.ndarray, torch.Tensor, shape=(-1, len, self.d)
			A set of examples to summarize.

		sample_weight: list, tuple, numpy.ndarray, torch.Tensor, optional
			A set of weights for the examples. This can be either of shape
			(-1, self.d) or a vector of shape (-1,). Default is ones.


		Returns
		-------
		logp: torch.Tensor, shape=(-1,)
			The log probability of each example.
		"""

		if self.frozen:
			return

		X, sample_weight = super().summarize(X, sample_weight=sample_weight)
		X = _check_parameter(X, "X", ndim=2, check_parameter=self.check_data)

		for i, factor in enumerate(self.factors):
			if isinstance(factor, NeuralDistribution) or isinstance(factor, EgoNeuralDistribution):
				factor.summarize(X[:, self._factor_edges[i]], 
					sample_weight=sample_weight[:,i], from_file=None)
			factor.summarize(X[:, self._factor_edges[i]], 
				sample_weight=sample_weight[:,i])

	def from_summaries(self):
		if self.frozen:
			return

		for distribution in self.factors:
			distribution.from_summaries()
	
	def load_from_file(self, folder_path):
		if self.frozen:
			return

		for factor in self.factors:
			if isinstance(factor, NeuralDistribution) or isinstance(factor, EgoNeuralDistribution):
				factor.load_from_file(folder_path)
