import numpy as np
from numpy.random import multivariate_normal

class DiscreteFactor():
	'''
	Discrete (categorical) factor over a single random variable whose values
	are the integers 0 .. card-1.
	'''

	def __init__(self, var, card, pmf):
		'''
		Store the description of a categorical random variable.
		- var: numerical id of the random variable. Ex: var=1
		- card: cardinality, i.e. how many values the variable can take.
			Ex: card=2 (Bernoulli)
		- pmf: probability of each value 0 .. card-1. Entries are expected to be
			non-negative and sum to one; this is not checked here, but
			np.random.choice raises ValueError for an invalid pmf when sampling.
			Ex: [0.3, 0.7] represents P(y=0) = 0.3, P(y=1) = 0.7
		'''
		self.var, self.card, self.pmf = var, card, pmf

	def sample(self, num_instances):
		'''
		Draw num_instances i.i.d. values from the categorical distribution
		given by self.pmf. Returns an integer array of shape num_instances
		whose entries lie in 0 .. card-1.
		'''
		return np.random.choice(self.card, size=num_instances, p=self.pmf)

class GaussianGivenDiscrete():
	'''
	This class creates a factor that follows a conditional Gaussian distribution:
	one (multivariate) Gaussian per joint assignment of the discrete parents.
	'''
	def __init__(self, parents_factors, mu_vector, sigma_vector):
		'''
		Assign a mu and sigma to each possible combination of parent values.
		- parents_factors: list of discrete factors; only their `card` attribute is read.
		- mu_vector: indexed by the row index below; each entry is passed as the
			mean to np.random.multivariate_normal.
		- sigma_vector: indexed like mu_vector; each entry is passed as the
			covariance argument of np.random.multivariate_normal (so presumably a
			square matrix despite the name -- confirm with callers).
		The first factor in the list is the one whose value changes most slowly in the table.
		Example (cardinalities 2, 2, 3):
		-----------------------------------------
		|Par_1|Par_2|Par_3| Mu_indx | Sigma_indx|
		   0     0     0       0           0
		   0     0     1       1           1
		   0     0     2       2           2
		   0     1     0       3           3
		   0     1     1       4           4
		   0     1     2       5           5
		   1     0     0       6           6
		   1     0     1       7           7
		   1     0     2       8           8
		   1     1     0       9           9
		   1     1     1      10          10
		   1     1     2      11          11
		-----------------------------------------
		'''
		self.parents_factors = parents_factors
		self.mu_vector = mu_vector
		self.sigma_vector = sigma_vector

		# Cardinality of each parent, in table order.
		self.var_cardinality = [factor.card for factor in parents_factors]

		# convert_a_to_i[k] = prod(card[k:]) with a trailing 1, e.g. cards
		# (2, 2, 3) -> (12, 6, 3, 1). Element 0 is the number of table rows;
		# elements 1: are the row-major strides used by assignment_to_indx.
		temp = np.hstack([self.var_cardinality, 1])
		temp = np.flip(temp, axis=0)
		temp = np.flip(np.cumprod(temp), axis=0)

		self.convert_a_to_i = temp

	def assignment_to_indx(self, assignment):
		'''
		Map parent assignments to row indices of mu_vector / sigma_vector.
		- assignment: 2-D array-like, one row per assignment and one column per
			parent (same order as parents_factors). Ex: [[1, 0, 2]]
		Returns an integer array of shape (num_rows, 1).
		'''
		# Row-major index = sum(value_k * prod(card[k+1:])), which reproduces
		# the table in __init__ for any number of parents.
		strides = self.convert_a_to_i[1:]
		index = np.sum(strides * np.asarray(assignment), axis=1)

		# Cast the whole vector (np.int no longer exists in NumPy >= 1.24), which
		# also allows several assignments to be mapped in one call.
		return np.asarray(index).astype(int).reshape(-1, 1)

	def sample(self, assignment, num_instances):
		'''
		Draw num_instances samples from the Gaussian selected by the first row
		of `assignment`; returns the array from np.random.multivariate_normal.
		'''
		# Get the parameters that correspond to this assignment
		indx = self.assignment_to_indx(assignment)[0][0]
		c_mu = self.mu_vector[indx]
		c_var = self.sigma_vector[indx]
		instances = np.random.multivariate_normal(c_mu, c_var, num_instances)

		return instances

class SoftmaxGivenParentsFactor():
	'''
	This factor outputs a distribution given by softmax of a linear combination of a set of weights
	and continuous features.
	This factor models a policy given the context (C, D).
	'''
	def __init__(self, var, card, theta, cParents, dParents):
		'''
		- var: numerical id of the random variable.
		- card: number of output values (0 .. card-1).
		- theta: indexed by the discrete-parent row index; each entry is a
			(num_features+1) x card matrix whose first row is the bias.
			Its length should equal the product of the dParents cardinalities.
		- cParents: continuous parents; accepted but not used by this class.
		- dParents: discrete parent factors; only their `card` is read. The first
			parent is the one whose value changes most slowly.
		'''
		self.var = var
		self.card = card
		self.theta = theta

		# Cardinality of each discrete parent, in table order.
		self.d_parent_cardinality = [factor.card for factor in dParents]

		# convert_a_to_i[k] = prod(card[k:]) with a trailing 1; element 0 is the
		# number of rows of theta, elements 1: are the row-major strides.
		temp = np.hstack([self.d_parent_cardinality, 1])
		temp = np.flip(temp, axis=0)
		temp = np.flip(np.cumprod(temp), axis=0)

		self.convert_a_to_i = temp

	def assignment_to_indx(self, assignment):
		'''
		Map discrete-parent assignments to indices into theta.
		- assignment: 2-D array-like, one row per assignment and one column per
			discrete parent. Ex: [[1, 0]]
		Returns an integer array of shape (num_rows, 1).
		'''
		# Row-major index = sum(value_k * prod(card[k+1:])), correct for any
		# number of parents.
		strides = self.convert_a_to_i[1:]
		index = np.sum(strides * np.asarray(assignment), axis=1)

		# np.int was removed in NumPy 1.24; cast the vector instead.
		return np.asarray(index).astype(int).reshape(-1, 1)

	def sample(self, dAssignment, cAssignment, num_instances):
		'''
		Sample num_instances outputs given the discrete assignment (first row of
		dAssignment) and the continuous features (cAssignment, one row per context).
		Returns (instances, pmf) where pmf is the softmax distribution for the
		first row of cAssignment.
		'''
		# Get the parameters that correspond to this assignment
		indx = self.assignment_to_indx(dAssignment)[0][0]

		# Theta is a matrix of (num_features+1) x num_outputs
		c_theta = np.asarray(self.theta[indx])

		# The first row of theta is the bias
		b = c_theta[0:1, :]

		# Linear combination of the continuous features and the weights
		logits = np.matmul(cAssignment, c_theta[1:, :]) + b

		# Shift by the per-row max so exp() cannot overflow to inf (which would
		# give NaN probabilities), then normalise each row on its own.
		logits = logits - np.max(logits, axis=1, keepdims=True)
		unnormalised = np.exp(logits)
		pmf = unnormalised / np.sum(unnormalised, axis=1, keepdims=True)

		# Sample from the pmf
		instances = np.random.choice(self.card, size=num_instances, p=pmf[0])

		return instances, pmf[0]

def generate_dataset(num_instances, D, C_given_D, A, R):
	'''
	Generate num_instances rows by Monte Carlo (ancestral) sampling:
	D -> C given D -> A given (D, C) -> R given (D, A, C).
	- D: factor with sample(n)
	- C_given_D: factor with sample(assignment, n)
	- A, R: factors with sample(dAssignment, cAssignment, n) returning
		(instances, pmf)
	Returns a float array of shape (num_instances, 5) with columns
	[d, c, a, p, r], where p is the probability A assigned to the sampled a
	(the propensity score) and c is rounded to 3 decimals.
	'''
	rows = []

	for _ in range(num_instances):
		# Generate a value for D
		d = D.sample(1)[0]

		# Generate a value for C given D
		c = np.round(C_given_D.sample(np.array([[d]]), 1)[0][0], 3)

		# Generate a value for A given D, C and keep its propensity score
		a_samples, pmf_a = A.sample([[d]], np.array([[c]]), 1)
		a = a_samples[0]
		p = pmf_a[a]

		# Generate the reward; its pmf is not needed
		r_samples, _ = R.sample([[d, a]], np.array([[c]]), 1)
		r = r_samples[0]

		rows.append(np.array([d, c, a, p, r]))

	# reshape keeps a 2-D (0, 5) table when num_instances == 0, so column
	# slicing such as dataset[:, 0] works for empty runs too.
	return np.array(rows).reshape(-1, 5)
	
def main():
	'''
	Script entry point. Placeholder: performs no work and returns -1.
	'''
	status = -1
	return status

if __name__ == '__main__':
	main()
