from enum import Enum
import itertools
from . import utils
from pysdd.sdd import SddManager, Vtree
from collections import defaultdict


def implies(x, y):
	"""
	Material implication ``x -> y``, encoded as ``~x | y``.

	Operands only need to support ``~`` and ``|``; in this module they are
	the SDD literals/formulas created from an ``SddManager``.
	"""
	negated_premise = ~x
	return negated_premise | y

def iff(x, y):
	"""
	Logical equivalence ``x <-> y``, built as ``(x -> y) & (y -> x)``.
	"""
	forward = implies(x, y)
	backward = implies(y, x)
	return forward & backward

def construct_vars(n_lay, n_mat):
	'''
	Create one SDD literal per (layer, material) cell.

	x[i][j] is the literal for "material j is present on layer i". Literals are
	numbered 1 .. n_lay * n_mat in layer-major order, i.e. x[i][j] is variable
	i * n_mat + j + 1.

	Args:
		n_lay (int): number of layers.
		n_mat (int): number of candidate materials.

	Returns:
		tuple: (x, manager, vtree), where x is a nested defaultdict of literals,
		manager is the SddManager that owns them, and vtree is the Vtree the
		manager was created from.
	'''
	vtree = Vtree(var_count=n_lay * n_mat)
	manager = SddManager.from_vtree(vtree)

	x = defaultdict(dict)
	cells = itertools.product(range(n_lay), range(n_mat))
	for var_index, (layer, material) in enumerate(cells, start=1):
		x[layer][material] = manager.literal(var_index)

	return x, manager, vtree


def construct_one_hot_formula(x, n_lay, n_mat):
	'''
	Build the "at most one material per layer" constraint.

	For every layer and every unordered pair of distinct materials (a, b), the
	clause ~x[layer][a] | ~x[layer][b] is conjoined. Together with
	construct_at_least_one_per_layer_formula this gives a one-hot encoding.

	Args:
		x: nested mapping of literals, x[layer][material] (as built by construct_vars).
		n_lay (int): number of layers.
		n_mat (int): number of materials.

	Returns:
		The conjunction of all pairwise exclusion clauses. This is the constant
		"true" when n_mat < 2.
	'''
	# x | ~x is the constant "true": the neutral element for the AND below.
	one_hot_formula = x[0][0] | ~x[0][0]
	for layer in range(n_lay):
		# (a -> ~b) and (b -> ~a) are the same clause ~a | ~b, so each
		# unordered pair is encoded once instead of twice.
		for mat, mat2 in itertools.combinations(range(n_mat), 2):
			one_hot_formula = one_hot_formula & (~x[layer][mat] | ~x[layer][mat2])

	return one_hot_formula

def construct_at_least_one_per_layer_formula(x, n_lay, n_mat):
	'''
	Build the "at least one material per layer" constraint.

	For every layer i the clause OR_j x[i][j] is conjoined.

	Args:
		x: nested mapping of literals, x[layer][material] (as built by construct_vars).
		n_lay (int): number of layers.
		n_mat (int): number of materials (must be >= 1).

	Returns:
		The conjunction over layers of the per-layer disjunctions.
	'''
	# x | ~x is the constant "true": the neutral element for the AND below.
	at_least_one = x[0][0] | ~x[0][0]
	for layer in range(n_lay):
		# Seed with material 0 so that a single-material setup (n_mat == 1)
		# does not index a non-existent material 1.
		big_or = x[layer][0]
		for mat in range(1, n_mat):
			big_or = big_or | x[layer][mat]

		at_least_one = at_least_one & big_or

	return at_least_one


def prevent_same_following_materials(x, n_lay, n_mat):
	'''
	Construct a formula that prevents two consecutive layers from using the same
	material. ROW = layers, COL = materials:

	AND [ x(i, j) -> ! x(i+1, j) ]

	The result is also conjoined with the one-hot and at-least-one-per-layer
	constraints.
	'''
	# Start from the constant "true" (x | ~x). Note that implies(x, ~x) would
	# reduce to ~x and wrongly forbid material 0 on layer 0.
	formula = x[0][0] | ~x[0][0]
	for i in range(n_lay - 1):
		for j in range(n_mat):
			formula = formula & implies(x[i][j], ~x[i + 1][j])

	one_hot_formula = construct_one_hot_formula(x, n_lay, n_mat)
	at_least_one = construct_at_least_one_per_layer_formula(x, n_lay, n_mat)

	return formula & one_hot_formula & at_least_one

def prevent_same_global_material(x, n_lay, n_mat):
	'''
	Build a formula in which every material is used on at most one layer.

	For each material j and each pair of layers i1 < i2:

	AND [ ~x(i1, j) | ~x(i2, j) ]   (i.e. x(i1, j) NAND x(i2, j))

	The result is conjoined with the one-hot and at-least-one-per-layer constraints.
	'''
	# x | ~x is the constant "true": the neutral element for the AND below.
	column_constraint = x[0][0] | ~x[0][0]

	layer_pairs = list(itertools.combinations(range(n_lay), 2))
	for mat in range(n_mat):
		for upper, lower in layer_pairs:
			column_constraint = column_constraint & (~x[upper][mat] | ~x[lower][mat])

	at_most_one = construct_one_hot_formula(x, n_lay, n_mat)
	some_material = construct_at_least_one_per_layer_formula(x, n_lay, n_mat)

	return column_constraint & at_most_one & some_material


def force_palindrome_material(x, n_lay, n_mat, up_to=2):
	'''
	Force the outer up_to layers of the stack to mirror each other.

	For every material j and every lay < up_to:

	x(lay, j) <-> x(n_lay - 1 - lay, j)

	up_to is clamped to n_lay. Asking for more mirrored layers than exist
	constrains the whole stack (a full palindrome) instead of failing on a
	missing layer. The result is conjoined with the one-hot and
	at-least-one-per-layer constraints.
	'''
	mirrored = min(up_to, n_lay)

	# x | ~x is the constant "true": the neutral element for the AND below.
	formula = x[0][0] | ~x[0][0]
	for material in range(n_mat):
		for lay in range(mirrored):
			formula = formula & iff(x[lay][material], x[n_lay - 1 - lay][material])

	one_hot_formula = construct_one_hot_formula(x, n_lay, n_mat)
	at_least_one = construct_at_least_one_per_layer_formula(x, n_lay, n_mat)

	return formula & one_hot_formula & at_least_one


def force_use_all_materials(x, n_lay, n_mat):
	'''
	Force every material to appear on at least one layer.

	For each material j: OR_i x(i, j).

	The result is conjoined with the one-hot and at-least-one-per-layer
	constraints. Each layer then holds at most one material, so the formula is
	unsatisfiable when n_lay < n_mat.
	'''
	# x | ~x is the constant "true": the neutral element for the AND below.
	formula = x[0][0] | ~x[0][0]

	for material in range(n_mat):
		# Seed with layer 0 so that a single-layer stack (n_lay == 1) does not
		# index a non-existent layer 1.
		big_or = x[0][material]
		for layer in range(1, n_lay):
			big_or = big_or | x[layer][material]

		formula = formula & big_or

	one_hot_formula = construct_one_hot_formula(x, n_lay, n_mat)
	at_least_one = construct_at_least_one_per_layer_formula(x, n_lay, n_mat)

	return formula & one_hot_formula & at_least_one


def force_hyperbolic_material(x, n_lay, n_mat, pattern_len):
	'''
	Force a periodic stack with period pattern_len.

	For every layer i < n_lay - pattern_len and every material j:
	  - x(i, j) <-> x(i + pattern_len, j)                 (periodicity)
	  - ~(x(i, j) & x(i + s, j)) for 0 < s < pattern_len  (no repeat inside the window)

	The result is conjoined with the one-hot and at-least-one-per-layer constraints.
	'''
	# x | ~x is the constant "true": the neutral element for the AND below.
	formula = x[0][0] | ~x[0][0]

	for layer in range(n_lay - pattern_len):
		for material in range(n_mat):
			here = x[layer][material]
			# Same material exactly one period later (and vice versa).
			formula = formula & iff(here, x[layer + pattern_len][material])

			# NOTE(review): window exclusions are only emitted for layers below
			# n_lay - pattern_len. For short stacks some in-window pairs stay free
			# (e.g. n_lay=4, pattern_len=3 leaves layers 1 and 2 unconstrained).
			# Confirm this matches the counting helper.
			for offset in range(1, pattern_len):
				formula = formula & ~(here & x[layer + offset][material])

	at_most_one = construct_one_hot_formula(x, n_lay, n_mat)
	some_material = construct_at_least_one_per_layer_formula(x, n_lay, n_mat)

	return formula & at_most_one & some_material



class SemanticExperiment(Enum):
	"""
	Catalogue of the semantic constraints used in the experiments.

	Each member provides:
	  - get_constraint_function: a builder for its SDD constraint formula,
	  - get_count_function: a counter for designs satisfying it,
	  - get_log_filenames: labels for its logs and plots.
	"""
	PALINDROME_2 = 1
	PALINDROME_3 = 2
	PALINDROME_4 = 3
	PERIODIC_2 = 4
	PERIODIC_3 = 5
	PERIODIC_4 = 6
	NO_ADJACENT = 7
	USE_ALL = 8

	def _dispatch(self, table, purpose):
		"""Return table[self], raising ValueError if this member is not mapped."""
		try:
			return table[self]
		except KeyError:
			raise ValueError("No {} defined for {!r}".format(purpose, self)) from None

	def get_constraint_function(self, num_lay, num_mat):
		"""
		Return a callable that builds the SDD constraint for this experiment.

		Args:
			num_lay (int): number of layers in the stack.
			num_mat (int): number of candidate materials.

		Returns:
			Callable: fun(x), which takes the variables x returned by construct_vars
			and returns the constraint formula. The formula is already conjoined
			with the one-hot and at-least-one-per-layer constraints.

		Raises:
			ValueError: if no constraint is registered for this member.

		Example:
			x, manager, vtree = construct_vars(num_lay, num_mat)
			phi = SemanticExperiment.PALINDROME_2
			formula = phi.get_constraint_function(num_lay, num_mat)(x)
		"""
		builders = {
			SemanticExperiment.PALINDROME_2: lambda x: force_palindrome_material(x, num_lay, num_mat, up_to=2),
			SemanticExperiment.PALINDROME_3: lambda x: force_palindrome_material(x, num_lay, num_mat, up_to=3),
			SemanticExperiment.PALINDROME_4: lambda x: force_palindrome_material(x, num_lay, num_mat, up_to=4),

			SemanticExperiment.PERIODIC_2: lambda x: force_hyperbolic_material(x, num_lay, num_mat, pattern_len=2),
			SemanticExperiment.PERIODIC_3: lambda x: force_hyperbolic_material(x, num_lay, num_mat, pattern_len=3),
			SemanticExperiment.PERIODIC_4: lambda x: force_hyperbolic_material(x, num_lay, num_mat, pattern_len=4),

			SemanticExperiment.NO_ADJACENT: lambda x: prevent_same_following_materials(x, num_lay, num_mat),

			SemanticExperiment.USE_ALL: lambda x: force_use_all_materials(x, num_lay, num_mat),
		}
		return self._dispatch(builders, "constraint function")

	def get_log_filenames(self):
		"""
		Return labels used for logging and visualization of this constraint.

		Returns:
			tuple:
				str: A name string to be used for the log file.
				str: A title string to be used for the plot.

		Raises:
			ValueError: if no labels are registered for this member.
		"""
		labels = {
			SemanticExperiment.PALINDROME_2: ("Palindrome_2", "Palindrome materials 2"),
			SemanticExperiment.PALINDROME_3: ("Palindrome_3", "Palindrome materials 3"),
			SemanticExperiment.PALINDROME_4: ("Palindrome_4", "Palindrome materials 4"),

			SemanticExperiment.PERIODIC_2: ("Periodic_2", "Periodic materials 2"),
			SemanticExperiment.PERIODIC_3: ("Periodic_3", "Periodic materials 3"),
			SemanticExperiment.PERIODIC_4: ("Periodic_4", "Periodic materials 4"),

			SemanticExperiment.NO_ADJACENT: ("No_adjacent", "No adjacent materials"),

			SemanticExperiment.USE_ALL: ("Use_all", "Use all materials"),
		}
		return self._dispatch(labels, "log filenames")

	def get_count_function(self, num_lay, num_mat):
		"""
		Return a callable that counts designs satisfying this constraint.

		Args:
			num_lay (int): number of layers in the stack.
			num_mat (int): number of candidate materials.

		Returns:
			Callable: fun(x), which forwards x to the matching utils counting helper.
			x is presumably the reconstructed material tensor (torch.Tensor) and the
			result is the sat count plus a mask selecting the satisfying materials
			from x. Both are assumptions, to be confirmed against utils.

		Raises:
			ValueError: if no count function is registered for this member.
		"""
		counters = {
			SemanticExperiment.PALINDROME_2: lambda x: utils.count_number_palindrome_materials(x, num_lay, num_mat, up_to=2),
			SemanticExperiment.PALINDROME_3: lambda x: utils.count_number_palindrome_materials(x, num_lay, num_mat, up_to=3),
			SemanticExperiment.PALINDROME_4: lambda x: utils.count_number_palindrome_materials(x, num_lay, num_mat, up_to=4),

			SemanticExperiment.PERIODIC_2: lambda x: utils.count_number_hyperbolic_materials(x, num_lay, num_mat, pattern_len=2),
			SemanticExperiment.PERIODIC_3: lambda x: utils.count_number_hyperbolic_materials(x, num_lay, num_mat, pattern_len=3),
			SemanticExperiment.PERIODIC_4: lambda x: utils.count_number_hyperbolic_materials(x, num_lay, num_mat, pattern_len=4),

			SemanticExperiment.NO_ADJACENT: lambda x: utils.count_no_consecutive_repeated_layers(x, num_lay, num_mat),

			SemanticExperiment.USE_ALL: lambda x: utils.count_number_metamat_use_all_materials(x, num_lay, num_mat),
		}
		return self._dispatch(counters, "count function")

		
