from itertools import product
from collections import OrderedDict

# Object shapes that appear in the scenes.
SHAPES = ['sphere', 'spl', 'cylinder', 'cube', 'cone']

# Movement suffixes used to build the atomic action classes below.
SLIDE = '_slide'
CONTAIN = '_contain'
PICK_PLACE = '_pick_place'
ROTATE = '_rotate'
NO_OP = '_no_op'

# Atomic action classes as (object, movement) pairs. Order matters: a class's
# position in this list is its integer id in CLASSES_MAP.
ACTION_CLASSES = [
	('sphere', SLIDE),
	('sphere', PICK_PLACE),
	('spl', SLIDE),
	('spl', PICK_PLACE),
	('spl', ROTATE),
	('cylinder', PICK_PLACE),
	('cylinder', SLIDE),
	('cylinder', ROTATE),
	('cube', SLIDE),
	('cube', PICK_PLACE),
	('cube', ROTATE),
	('cone', CONTAIN),
	('cone', PICK_PLACE),
	('cone', SLIDE),
]

# Temporal relations between two actions.
_BEFORE = 'before'
_DURING = 'during'
_AFTER = 'after'
ORDERING = [_BEFORE, _DURING, _AFTER]

SIZES = ['small', 'medium', 'large']

# Inverse of each relation ('before' <-> 'after', 'during' -> 'during').
# Relies on ORDERING being symmetric around 'during'.
inv_ord = dict(zip(ORDERING, reversed(ORDERING)))

# (object, movement) -> class index.
CLASSES_MAP = {action: index for index, action in enumerate(ACTION_CLASSES)}

MAX_FRAMES = 301

def enumerate_composite_actions(action_classes=None, ordering=None):
	"""Assign a class id to every (action, relation, action) triple.

	A triple ``(a, rel, b)`` and its mirror image ``(b, inverse(rel), a)``
	describe the same event, so they share one id. New ids are handed out
	consecutively from 0 in iteration order.

	Args:
		action_classes: sequence of atomic actions; defaults to ACTION_CLASSES.
		ordering: sequence of relations, assumed symmetric so that reversing it
			maps each relation to its inverse; defaults to ORDERING.

	Returns:
		dict mapping each ``(action_sub, rel, action_obj)`` triple to its id.
	"""
	if action_classes is None:
		action_classes = ACTION_CLASSES
	if ordering is None:
		ordering = ORDERING
	# Same construction as the module-level ``inv_ord``.
	inverse = dict(zip(ordering, reversed(ordering)))
	composite_actions = {}
	next_id = 0
	for action_sub in action_classes:
		for rel in ordering:
			for action_obj in action_classes:
				reverse_action = (action_obj, inverse[rel], action_sub)
				if reverse_action in composite_actions:
					# Mirror image already numbered: reuse its id.
					class_id = composite_actions[reverse_action]
				else:
					# Ids are contiguous from 0, so a running counter equals the
					# number of distinct ids so far. Recomputing
					# len(set(values)) each time was quadratic.
					class_id = next_id
					next_id += 1
				composite_actions[(action_sub, rel, action_obj)] = class_id
	return composite_actions

# Lookup table built at import time: (action_sub, relation, action_obj) ->
# composite class id, where a triple and its mirror image share one id.
COMP_ACTIONS = enumerate_composite_actions()

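# Reference only: the original `reverse` helper from the CATER composite-actions
# code, which took a one-element relation tuple such as ('during',).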
# def reverse(el):
# 	if el == ('during',):
# 		return el
# 	elif el == ('before',):
# 		return ('after',)
# 	elif el == ('after',):
# 		return ('before',)
# 	else:
# 		raise ValueError('This should not happen')

def reverse_rel(el):
	"""Return the inverse of a temporal relation.

	'before' <-> 'after'; 'during' is its own inverse.

	Raises:
		ValueError: if ``el`` is not one of the three relations.
	"""
	if el == 'during':
		return el
	elif el == 'before':
		return 'after'
	elif el == 'after':
		return 'before'
	else:
		raise ValueError('This should not happen')

def reverse(el):
	"""Return the mirror image of a composite action ``(actions, relations)``.

	Assumes ``relations[i]`` relates ``actions[i]`` to ``actions[i + 1]``
	(``actions_order_dataset`` pairs n actions with n-1 relations). Reading
	the chain backwards reverses the actions. Each relation must be inverted
	AND moved to the mirrored position, so the relation tuple is reversed
	too. For two actions (one relation) the reordering is a no-op.

	Example: ((A, B, C), ('before', 'during')) -> ((C, B, A), ('during', 'after')).
	"""
	actions, relations = el[0], el[1]
	rev_actions = actions[::-1]
	# Reverse the relation order to match the reversed actions (previously
	# only the actions were reversed, which is wrong for 3+ actions).
	rev_rel = tuple(reverse_rel(r) for r in relations[::-1])
	return (rev_actions, rev_rel)

def action_order_unique(classes, as_list=True):
	"""Drop composite actions whose mirror image was already kept.

	Walks ``classes`` in order and keeps an element only if neither it nor
	``reverse(el)`` has been kept before. "X before Y" and "Y after X"
	therefore collapse to whichever comes first. Elements must be hashable,
	as the tuples built by ``actions_order_dataset`` are.

	Args:
		classes: iterable of composite actions accepted by ``reverse``.
		as_list: if True return a list; otherwise an OrderedDict mapping each
			kept class to a consecutive index starting at 0.

	Returns:
		list or OrderedDict of the kept classes, in first-seen order.
	"""
	kept = []
	# Set mirror of ``kept`` for O(1) membership tests; scanning the list
	# made this quadratic in the number of classes.
	seen = set()
	for el in classes:
		if el in seen or reverse(el) in seen:
			continue
		seen.add(el)
		kept.append(el)
	if as_list:
		return kept
	return OrderedDict((el, index) for index, el in enumerate(kept))

def actions_order_dataset(atomic_events=ACTION_CLASSES, n=2, unique=True, as_list=True, samples=None):
	"""Enumerate composite-action classes of n atomic events and n-1 relations.

	Each class is ``(events, relations)``. ``events`` is an n-tuple drawn from
	``atomic_events`` with repetition; ``relations`` is an (n-1)-tuple drawn
	from ORDERING.

	Args:
		atomic_events: atomic actions to combine.
		n: number of atomic events per class.
		unique: if True, drop classes whose mirror image (see ``reverse``) was
			already enumerated.
		as_list: if True return a list; otherwise an OrderedDict mapping each
			class to a consecutive index starting at 0.
		samples: slice step applied to the enumeration (None keeps everything).

	Returns:
		list or OrderedDict of classes, as selected by ``as_list``.
	"""
	action_sets = list(product(atomic_events, repeat=n))
	# All orderings between consecutive events.
	orderings = list(product(ORDERING, repeat=(n - 1)))
	# All actions and orderings.
	classes = list(product(action_sets, orderings))
	if unique:
		# Remove classes such as "X before Y" when "Y after X" already exists.
		classes = action_order_unique(classes, as_list=True)
	sampled = classes[::samples]
	if as_list:
		return sampled
	# Build the index mapping from the sampled sequence. The previous code
	# called ``classes.items()``, which raised AttributeError when
	# unique=False left ``classes`` a plain list.
	return OrderedDict((key, index) for index, key in enumerate(sampled))

# def consistent_action(comp_action):
# 	return ((comp_action[0][1], comp_action[0][0]), reverse(comp_action[1]))

def actions_order_mapping(n=2):
	"""Pair each unique composite action with the index of its mirror image.

	For every class kept by ``actions_order_dataset(n=n, unique=True)``,
	return ``(class_idx, other_class_idx)``. These are the positions of the
	class and of ``reverse(class)`` in the full, non-deduplicated enumeration.

	Args:
		n: number of atomic events per composite action.

	Returns:
		list of ``(class_idx, other_class_idx)`` tuples, one per unique class.
	"""
	# Forward ``n``; it used to be ignored, so both calls always used n=2.
	comp_actions = actions_order_dataset(n=n, unique=False)
	uniq_comp_actions = actions_order_dataset(n=n, unique=True)
	# Position lookup; list.index inside the loop was O(len) per query.
	position = {comp_action: idx for idx, comp_action in enumerate(comp_actions)}
	return [
		(position[u_comp_action], position[reverse(u_comp_action)])
		for u_comp_action in uniq_comp_actions
	]

def check_rule_match(gt_rule, pred_rule):
	"""Return True if two rules hold the same predicates, up to mirroring.

	The rules must have equal length, and every predicate of ``gt_rule`` must
	appear in ``pred_rule`` either as-is or as its mirror image
	(``reverse``). Membership is tested per predicate, so multiplicities are
	not compared.
	"""
	if len(gt_rule) != len(pred_rule):
		return False
	accepted = set()
	for candidate in pred_rule:
		accepted.add(candidate)
		accepted.add(reverse(candidate))
	return all(predicate in accepted for predicate in gt_rule)

if __name__ == '__main__':
	# Report how many composite-action classes exist before and after
	# collapsing mirror images ("X before Y" == "Y after X").
	classes = actions_order_dataset(unique=False)
	uniq_classes = actions_order_dataset(unique=True)
	print('Action orders classes: {} total, {} unique'.format(len(classes), len(uniq_classes)))