from temporal.predicates import slide, pick_place, contain
from vision.visionutils import build_label, build_unique_label, MAX_FRAMES
from datasets.cater.enums import * 
from collections import defaultdict, Counter, namedtuple
from numpy.random import uniform, randint
from scipy.spatial.distance import euclidean

import json
import logging
import numpy as np

# Module-level logger for this file.
# NOTE: basicConfig is executed at import time, so merely importing this module
# configures the root logger at INFO level (when no handlers are set up yet).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def during(sub_start, sub_end, obj_start, obj_end, margin=0.1):
	"""Tell whether two intervals overlap by more than ``margin``.

	The overlap is measured as intersection length over union length of the
	intervals ``[sub_start, sub_end]`` and ``[obj_start, obj_end]``.
	Background on the overlap test:
	https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap

	:param margin: ratio that intersection/union must strictly exceed.
	:return: truthy only when the intersection is positive and the ratio
		is greater than ``margin``.
	"""
	latest_start = max(sub_start, obj_start)
	earliest_end = min(sub_end, obj_end)
	overlap = earliest_end - latest_start

	has_overlap = overlap > 0
	if not has_overlap:
		# no (or zero-length) intersection: nothing to divide
		return has_overlap

	span = max(sub_end, obj_end) - min(sub_start, obj_start)
	return overlap / span > margin

def during_soft(sub_start, sub_end, obj_start, obj_end):
	"""Return the soft overlap (intersection over union) of two intervals.

	Background on the overlap test:
	https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap

	:param sub_start: start of the first interval.
	:param sub_end: end of the first interval.
	:param obj_start: start of the second interval.
	:param obj_end: end of the second interval.
	:return: intersection/union, clamped below at 0 for disjoint intervals.
		Returns 0 when the union has zero length (both intervals collapse
		onto the same instant) instead of dividing by zero.
	"""
	intersection = min(sub_end, obj_end) - max(sub_start, obj_start)
	union = max(sub_end, obj_end) - min(sub_start, obj_start)
	if union == 0:
		# degenerate case: two identical zero-length intervals
		return 0
	return max(0, intersection / union)

def get_shape(obj_name):
	"""Return the fourth whitespace-separated token of ``obj_name``.

	Only this token is needed for classification.
	NOTE(review): presumably labels look like "<size> <color> <material> <shape>"
	-- confirm against the label builder.
	"""
	tokens = obj_name.split()
	return tokens[3]

def get_size(obj_name):
	"""Return the first whitespace-separated token of ``obj_name`` (used as the size)."""
	tokens = obj_name.split()
	return tokens[0]

class RawPrediction:
	"""Mutable record describing one predicted action over a time interval.

	Attributes are plain and may be reassigned after construction:
	``prob``, ``subject``, ``obj``, ``action``, ``start``, ``end``.
	The shape is stored privately and exposed through the ``shape`` property.
	"""

	def __init__(self, prob=None, shape=None, subject=None, obj=None, action=None, start=None, end=None):
		self.prob = prob
		self.subject = subject
		self.obj = obj
		self.action = action
		self.start = start
		self.end = end
		# resolved lazily from the subject name when not supplied
		self._shape = shape

	@property
	def shape(self):
		"""Shape of the subject; derived from ``subject`` via get_shape on first access if unset."""
		if self._shape is not None:
			return self._shape
		self._shape = get_shape(self.subject)
		return self._shape

	def __str__(self):
		summary = (self.prob, (self.shape, self.action), self.start, self.end)
		return str(summary)

	__repr__ = __str__

def sigmoid(x):
	"""Logistic function ``1 / (1 + exp(-x))`` computed with numpy."""
	decay = np.exp(-x)
	return 1 / (1 + decay)


class StateSpaceClassifier:
	"""Rule-based classifier assigning one atomic action to every object movement.

	Each movement interval listed under ``'object_movements'`` is evaluated with
	the ``slide``, ``pick_place`` and ``contain`` predicates. Results are kept as
	RawPrediction objects in ``predictions_raw`` and as ``(shape, action)`` tuples
	in ``predictions``.
	"""

	def __init__(self, state_space):
		"""Store the state space and start with empty prediction buffers.

		:param state_space: mapping with an ``'objects'`` entry (object name ->
			per-object data). ``sample_gt_action`` additionally reads
			``state_space['file_path']``.
		"""
		self.state_space = state_space
		self.state_space_objects = self.state_space['objects']
		self.obj_names = list(self.state_space_objects.keys())
		self.predictions = []  # (shape, action) tuples, filled by predict_all
		self.predictions_raw = []  # RawPrediction instances

	def map_to_classes(self):
		"""Return the set of CLASSES_MAP values for all non-NO_OP predictions found in CLASSES_MAP."""
		return {
			CLASSES_MAP[prediction]
			for prediction in self.predictions
			if prediction[1] != NO_OP and prediction in CLASSES_MAP
		}

	def add_contained_slides(self):
		"""Propagate slides of a container to the object it contains.

		If object O contains O', and then O slides, then O' slides as well.
		This should technically be defined as a formal rule too, but that is
		more complicated. Sorts ``predictions_raw`` by start time and appends
		the derived SLIDE predictions to it.
		"""
		self.predictions_raw = sorted(self.predictions_raw, key=lambda x: x.start)  # sort by start time
		for index, contain_prediction in enumerate(self.predictions_raw):
			if contain_prediction.action != CONTAIN:  # subject, obj, CONTAIN
				continue
			# while the container subject is sliding, so is the contained object;
			# the slice is a copy, so appending below does not disturb this inner loop
			for slide_prediction in self.predictions_raw[index + 1:]:
				if slide_prediction.subject != contain_prediction.subject:
					continue
				if slide_prediction.action != SLIDE:
					# if the container performs another action, i.e. pick-place,
					# then the contained object is no longer restricted
					break
				self.predictions_raw.append(RawPrediction(
					prob=slide_prediction.prob, subject=contain_prediction.obj, action=SLIDE,
					start=slide_prediction.start, end=slide_prediction.end))

	def format_predictions_raw(self):
		"""Append ``((shape, action), start, end)`` to each subject's ``'movement_predictions'`` list."""
		for prediction in self.predictions_raw:
			output_tuple = ((get_shape(prediction.subject), prediction.action), prediction.start, prediction.end)
			data = self.state_space_objects[prediction.subject]
			data.setdefault('movement_predictions', []).append(output_tuple)

	@staticmethod
	def _unpack_movement(movement):
		"""Return ``(shape, action, start, end)`` for one ground-truth movement.

		StateSpaceGT.predict_all stores RawPrediction objects, which cannot be
		indexed; the plain ``((shape, action), start, end)`` tuple layout that
		the previous indexing code expected is still accepted.
		"""
		if isinstance(movement, RawPrediction):
			return movement.shape, movement.action, movement.start, movement.end
		shape, movement_action = movement[0]
		return shape, movement_action, movement[1], movement[2]

	def sample_gt_action(self, action, acceptance_rate=0.6):
		"""Sample ground-truth movements of ``action`` into ``predictions_raw``.

		Used for rotations, since those cannot be predicted yet.
		TODO: just have these available when constructing the initial state
		space, not as an ad-hoc addition.

		:param action: ground-truth action label to sample.
		:param acceptance_rate: each matching movement is kept when a uniform
			draw is below this value. Unless the rate is exactly 1.0, start and
			end are jittered by random integer offsets and clamped to
			``[0, MAX_FRAMES - 1]``.
		"""
		# context manager so the annotation file handle is closed promptly
		with open(self.state_space['file_path']) as gt_file:
			gt_movements = StateSpaceGT(json.load(gt_file))
		gt_movements.predict_all()
		for movement in gt_movements.predictions_raw:
			shape, movement_action, start, end = self._unpack_movement(movement)
			if movement_action != action:
				continue
			if uniform(size=1).item() >= acceptance_rate:
				continue
			if acceptance_rate != 1.0:
				rand_times = randint(-4, high=4, size=2)
				sampled_movement = RawPrediction(prob=acceptance_rate, shape=shape, action=movement_action)
				sampled_movement.start = max(0, start + rand_times[0])
				sampled_movement.end = min(MAX_FRAMES - 1, end + rand_times[1])
			else:
				sampled_movement = RawPrediction(prob=1.0, shape=shape, action=movement_action, start=start, end=end)
			self.predictions_raw.append(sampled_movement)

	def predict_all(self, acceptance_rate=0.6):
		"""Classify every movement of every object and fill both prediction lists.

		:param acceptance_rate: forwarded to ``sample_gt_action`` for ROTATE.
		"""
		for subject, data in self.state_space_objects.items():
			for start, end in data['object_movements']:
				self.predictions_raw.append(self.predict(subject, start, end))

		self.add_contained_slides()
		# must run before sampling: sampled predictions carry no subject name
		self.format_predictions_raw()
		self.sample_gt_action(ROTATE, acceptance_rate=acceptance_rate)
		self.predictions_raw = sorted(self.predictions_raw, key=lambda x: x.start)  # sort by start time
		self.predictions = [(prediction.shape, prediction.action) for prediction in self.predictions_raw]

	def predict(self, subject, start, end):
		"""Classify one movement of ``subject`` over ``[start, end]``.

		:return: a RawPrediction with prob 1 whose action is SLIDE, CONTAIN
			(with ``obj`` set to the contained object), PICK_PLACE, or NO_OP
			when no predicate matches.
		"""
		slide_pred = slide.evaluate(subject, start, self.state_space_objects, end_time=end)
		if slide_pred:
			# the prediction used to be constructed here and then discarded
			return RawPrediction(prob=1, subject=subject, action=SLIDE, start=start, end=end)

		pick_place_pred = pick_place.evaluate(subject, start, self.state_space_objects, end_time=end)
		if pick_place_pred:
			# if it is a pick and place, then it may contain another object, check that first
			for other_obj in self.obj_names:
				if other_obj == subject:
					continue
				contain_pred = contain.evaluate(subject, end, self.state_space_objects, obj=other_obj)
				if contain_pred:
					return RawPrediction(prob=1, subject=subject, obj=other_obj, action=CONTAIN, start=start, end=end)

			return RawPrediction(prob=1, subject=subject, action=PICK_PLACE, start=start, end=end)

		# still have to add rotate

		# didn't find any match
		return RawPrediction(prob=1, subject=subject, action=NO_OP, start=start, end=end)

	def composite_actions_raw(self):
		"""Return the set of ``(subject_action, relation, object_action)`` triples.

		Every prediction is paired with each later one in ``predictions_raw``
		(already sorted by start time); the relation is _DURING when the two
		intervals overlap by more than 0.1 IoU, otherwise _BEFORE.
		NOTE(review): _DURING/_BEFORE are only reachable through the star
		import if the enums module lists them in ``__all__`` -- confirm.
		"""
		composite = set()
		for start_index, subj_pred in enumerate(self.predictions_raw):
			subj_aa = (subj_pred.shape, subj_pred.action)
			for obj_pred in self.predictions_raw[start_index + 1:]:
				obj_aa = (obj_pred.shape, obj_pred.action)
				overlapping = during(subj_pred.start, subj_pred.end, obj_pred.start, obj_pred.end, margin=.1)
				relation = _DURING if overlapping else _BEFORE
				composite.add((subj_aa, relation, obj_aa))
		return composite

	def composite_actions(self, composite_actions_raw):
		"""Map raw composite triples to their COMP_ACTIONS values, skipping unknown triples."""
		return {COMP_ACTIONS[car] for car in composite_actions_raw if car in COMP_ACTIONS}


# TODO: Extend StateSpaceClassifier
class StateSpaceClassifierSoft:
	"""Soft variant of the state-space classifier that keeps predicate scores.

	Instead of a single hard decision per movement, ``predict`` ranks candidate
	actions by the score returned from the predicates' ``evaluate_soft`` and
	keeps the ``top_k`` best. Composite actions are returned with feature
	vectors rather than as a plain set.
	"""

	def __init__(self, state_space):
		"""Store the state space and start with empty prediction buffers.

		:param state_space: mapping with an ``'objects'`` entry (object name ->
			per-object data). ``sample_gt_action`` additionally reads
			``state_space['file_path']``.
		"""
		self.state_space = state_space
		self.state_space_objects = self.state_space['objects']
		self.obj_names = list(self.state_space_objects.keys())
		self.predictions = []  # (shape, action) tuples, filled by predict_all
		self.predictions_raw = []  # RawPrediction instances

	def map_to_classes(self):
		"""Return the set of CLASSES_MAP values for all non-NO_OP predictions found in CLASSES_MAP."""
		return {
			CLASSES_MAP[prediction]
			for prediction in self.predictions
			if prediction[1] != NO_OP and prediction in CLASSES_MAP
		}

	def add_contained_slides(self):
		"""Propagate slides of a container to the object it contains.

		If object O contains O', and then O slides, then O' slides as well.
		This should technically be defined as a formal rule too, but that is
		more complicated. Sorts ``predictions_raw`` by start time and appends
		the derived SLIDE predictions to it.
		"""
		self.predictions_raw = sorted(self.predictions_raw, key=lambda x: x.start)  # sort by start time
		for index, prediction in enumerate(self.predictions_raw):
			if prediction.action != CONTAIN:  # subject, obj, CONTAIN
				continue
			# while the container subject is sliding, so is the contained object;
			# the slice is a copy, so appending below does not disturb this inner loop
			for obj_prediction in self.predictions_raw[index + 1:]:
				if obj_prediction.subject != prediction.subject:
					continue
				if obj_prediction.action != SLIDE:
					# if the container performs another action, i.e. pick-place,
					# then the contained object is no longer restricted
					break
				self.predictions_raw.append(RawPrediction(
					prob=obj_prediction.prob, subject=prediction.obj, action=SLIDE,
					start=obj_prediction.start, end=obj_prediction.end))

	# TODO: ideally keep the data structure moving forward
	def format_predictions_raw(self):
		"""Append ``((shape, action), start, end)`` to each subject's ``'movement_predictions'`` list."""
		for prediction in self.predictions_raw:
			output_tuple = ((get_shape(prediction.subject), prediction.action), prediction.start, prediction.end)
			data = self.state_space_objects[prediction.subject]
			data.setdefault('movement_predictions', []).append(output_tuple)

	@staticmethod
	def _unpack_movement(movement):
		"""Return ``(shape, action, start, end)`` for one ground-truth movement.

		StateSpaceGT.predict_all stores RawPrediction objects, which cannot be
		indexed; the plain ``((shape, action), start, end)`` tuple layout that
		the previous indexing code expected is still accepted.
		"""
		if isinstance(movement, RawPrediction):
			return movement.shape, movement.action, movement.start, movement.end
		shape, movement_action = movement[0]
		return shape, movement_action, movement[1], movement[2]

	def sample_gt_action(self, action, acceptance_rate=0.6):
		"""Sample ground-truth movements of ``action`` into ``predictions_raw``.

		Used for rotations, since those cannot be predicted yet.
		TODO: just have these available when constructing the initial state
		space, not as an ad-hoc addition.

		:param action: ground-truth action label to sample.
		:param acceptance_rate: each matching movement is kept when a uniform
			draw is below this value. Unless the rate is exactly 1.0, start and
			end are jittered by random integer offsets and clamped to
			``[0, MAX_FRAMES - 1]``.
		"""
		# context manager so the annotation file handle is closed promptly
		with open(self.state_space['file_path']) as gt_file:
			gt_movements = StateSpaceGT(json.load(gt_file))
		gt_movements.predict_all()
		for movement in gt_movements.predictions_raw:
			shape, movement_action, start, end = self._unpack_movement(movement)
			if movement_action != action:
				continue
			if uniform(size=1).item() >= acceptance_rate:
				continue
			if acceptance_rate != 1.0:
				rand_times = randint(-4, high=4, size=2)
				sampled_movement = RawPrediction(prob=acceptance_rate, shape=shape, action=movement_action)
				sampled_movement.start = max(0, start + rand_times[0])
				sampled_movement.end = min(MAX_FRAMES - 1, end + rand_times[1])
			else:
				sampled_movement = RawPrediction(prob=1.0, shape=shape, action=movement_action, start=start, end=end)
			self.predictions_raw.append(sampled_movement)

	def predict_all(self, acceptance_rate=0.6, top_k=1):
		"""Classify every movement of every object and fill both prediction lists.

		:param acceptance_rate: forwarded to ``sample_gt_action`` for ROTATE.
		:param top_k: number of ranked candidates kept per movement.
		"""
		for subject, data in self.state_space_objects.items():
			for start, end in data['object_movements']:
				# predict returns a list of candidates, hence extend
				self.predictions_raw.extend(self.predict(subject, start, end, top_k=top_k))

		self.add_contained_slides()
		# must run before sampling: sampled predictions carry no subject name
		self.format_predictions_raw()
		self.sample_gt_action(ROTATE, acceptance_rate=acceptance_rate)
		self.predictions_raw = sorted(self.predictions_raw, key=lambda x: x.start)  # sort by start time
		self.predictions = [(prediction.shape, prediction.action) for prediction in self.predictions_raw]

	def predict(self, subject, start, end, top_k=1):
		"""Rank candidate actions for one movement by their soft predicate score.

		SLIDE and PICK_PLACE are always scored; CONTAIN candidates (one per
		other object) are only added when pick-place scores above slide.

		:return: list of the ``top_k`` highest-scoring RawPrediction objects.
		"""
		slide_pred = slide.evaluate_soft(subject, start, self.state_space_objects, end_time=end)
		pick_place_pred = pick_place.evaluate_soft(subject, start, self.state_space_objects, end_time=end)
		solutions = [
			RawPrediction(prob=slide_pred, subject=subject, action=SLIDE, start=start, end=end),
			RawPrediction(prob=pick_place_pred, subject=subject, action=PICK_PLACE, start=start, end=end),
		]

		if pick_place_pred > slide_pred:
			for other_obj in self.obj_names:
				if other_obj == subject:
					continue
				contain_pred = contain.evaluate_soft(subject, end, self.state_space_objects, obj=other_obj)
				solutions.append(RawPrediction(prob=contain_pred, subject=subject, obj=other_obj, action=CONTAIN, start=start, end=end))

		solutions = sorted(solutions, key=lambda x: x.prob, reverse=True)
		return solutions[:top_k]

	def composite_actions_raw(self):
		"""Return a list of ``(feature, (subject_action, relation, object_action))`` pairs.

		Every prediction is paired with each later one in ``predictions_raw``
		(already sorted by start time). ``feature`` is a length-4 array:
		subject score, interval IoU, sigmoid(object start - subject end),
		object score. The relation is _DURING when IoU exceeds 0.1, otherwise
		_BEFORE.
		NOTE(review): _DURING/_BEFORE are only reachable through the star
		import if the enums module lists them in ``__all__`` -- confirm.
		"""
		margin = 0.1
		pairs = []
		for start_index, subj_pred in enumerate(self.predictions_raw):
			for obj_pred in self.predictions_raw[start_index + 1:]:
				subj_aa = (subj_pred.shape, subj_pred.action)
				obj_aa = (obj_pred.shape, obj_pred.action)
				iou = during_soft(subj_pred.start, subj_pred.end, obj_pred.start, obj_pred.end)

				feature = np.zeros(4)
				feature[0] = subj_pred.prob
				feature[1] = iou
				feature[2] = sigmoid(obj_pred.start - subj_pred.end)
				feature[3] = obj_pred.prob

				relation = _DURING if iou > margin else _BEFORE
				pairs.append((feature, (subj_aa, relation, obj_aa)))
		return pairs

	def composite_actions(self, composite_actions_raw):
		"""Average the feature vectors per composite action.

		:param composite_actions_raw: pairs as produced by ``composite_actions_raw``.
		:return: defaultdict mapping each COMP_ACTIONS value to the element-wise
			mean of the features of all triples that map to it; triples not in
			COMP_ACTIONS are ignored.
		"""
		ca = defaultdict(list)
		for feature, car in composite_actions_raw:
			if car in COMP_ACTIONS:
				ca[COMP_ACTIONS[car]].append(feature)
		# only values of existing keys are replaced, so iterating items() is safe
		for comp_index, features in ca.items():
			ca[comp_index] = np.mean(features, axis=0)
		return ca

class StateSpaceGT(StateSpaceClassifier):
	"""Ground-truth counterpart of the classifier, fed from the annotation JSON.

	Reads ``state_space['objects']`` (records with an ``'instance'`` key) and
	``state_space['movements']`` (instance -> list of movements) instead of
	evaluating predicates.
	"""

	def __init__(self, state_space):
		"""Keep the loaded JSON and build the instance -> label lookup.

		:param state_space: the loaded json file (parent ``__init__`` is not called).
		"""
		self.state_space = state_space
		self.obj_instances = {}
		self.predictions = []
		self.predictions_raw = []
		self.get_object_names()
		self.prediction_by_object = defaultdict(list)

	def get_object_names(self):
		"""Fill ``obj_instances`` with a label from build_unique_label for every object record."""
		for record in self.state_space['objects']:
			label = build_unique_label(record, self.obj_instances)
			self.obj_instances[record['instance']] = label

	def predict_all(self):
		"""Turn every annotated movement into a prediction.

		Object instances look like "SmoothCone_1"; each movement is laid out as
		``[action, optional object, start, end]``. Appends ``(shape, action)``
		to ``predictions`` and a prob-1.0 RawPrediction to ``predictions_raw``
		and ``prediction_by_object``; afterwards adds contained slides and
		sorts ``predictions_raw`` by start time.
		"""
		for obj_instance, movements in self.state_space['movements'].items():
			subj_str = self.obj_instances[obj_instance]
			# fourth token of the label is used as the shape
			shape_str = subj_str.split()[3]
			for movement in movements:
				action, start, end = movement[0], movement[2], movement[3]
				self.predictions.append((shape_str, action))

				raw = RawPrediction(prob=1.0, shape=shape_str, subject=subj_str, action=action, start=start, end=end)
				if action == CONTAIN:
					# a contain movement names the contained instance in slot 1
					raw.obj = self.obj_instances[movement[1]]
				self.predictions_raw.append(raw)
				self.prediction_by_object[subj_str].append(raw)

		self.add_contained_slides()
		self.predictions_raw = sorted(self.predictions_raw, key=lambda v: v.start)  # sort by start time

	def compute_composite_actions(self):
		"""Run ``predict_all`` and return the composite actions derived from it."""
		self.predict_all()
		raw_triples = self.composite_actions_raw()
		return self.composite_actions(raw_triples)


class DisentangledClassifier(StateSpaceClassifier):
	"""StateSpaceClassifier that resolves composite actions through a caller-supplied mapping."""

	def __init__(self, state_space, comp_actions):
		"""
		:param state_space: forwarded unchanged to the parent constructor.
		:param comp_actions: mapping from raw composite triples to their values,
			used in place of the module-level COMP_ACTIONS.
		"""
		super().__init__(state_space)
		self.comp_actions = comp_actions

	def composite_actions(self, composite_actions_raw):
		"""Return the set of mapped values for every raw triple present in ``comp_actions``."""
		return {
			self.comp_actions[car]
			for car in composite_actions_raw
			if car in self.comp_actions
		}

def centroid(bbox):
	"""Return the element-wise mean of ``bbox[[2, 3]]`` and ``bbox[[0, 1]]``.

	``bbox`` must support list-based (fancy) indexing, e.g. a numpy array.
	NOTE(review): presumably laid out as (x0, y0, x1, y1) -- confirm.
	"""
	corners = np.asarray([bbox[[2, 3]], bbox[[0, 1]]])
	return corners.mean(axis=0)

def is_contained(subj_centroid, obj_centroid, subject, obj_name):
	"""Heuristic check that ``subject`` covers ``obj_name``.

	True only when the two centroids are less than 15 apart (Euclidean), the
	subject's shape is 'cone', and the subject's size comes later in SIZES
	than the object's size.
	NOTE(review): presumably SIZES is ordered small-to-large and the distance
	is in pixels -- confirm.
	"""
	distance = euclidean(subj_centroid, obj_centroid)
	is_close = distance < 15
	is_cone = get_shape(subject) == 'cone'
	# all three checks are evaluated up front, so an unknown size still raises
	subject_rank = SIZES.index(get_size(subject))
	object_rank = SIZES.index(get_size(obj_name))
	is_larger = subject_rank > object_rank

	return is_close and is_cone and is_larger

class VideoStateSpaceClassifier(StateSpaceClassifierSoft):
	"""Classifier driven by the ``'pred_acts'`` labels stored per object.

	Each movement takes the most frequent label inside its interval; a
	pick-place is upgraded to CONTAIN using object positions at the end frame.
	"""

	def predict(self, subject, start, end):
		"""Predict the action of ``subject`` over ``[start, end)``.

		:return: one RawPrediction whose ``prob`` is the share of the interval
			covered by the most frequent label.
		"""
		window = self.state_space_objects[subject]['pred_acts'][start: end]
		ranked = Counter(window).most_common()

		# most common action from interval, together with its count
		action, votes = ranked[0]

		# prior knowledge to add contain predictions based on end position of
		# current object and other surrounding objects
		contained = None
		if action == PICK_PLACE:
			subj_centroid = centroid(self.state_space_objects[subject]['object_loc'][end])
			for obj_name, data in self.state_space_objects.items():
				if obj_name == subject:
					continue
				obj_centroid = centroid(data['object_loc'][end])
				if is_contained(subj_centroid, obj_centroid, subject, obj_name):
					# no early exit: the last matching object is the one kept
					action = CONTAIN
					contained = obj_name

		confidence = votes / len(window)
		return RawPrediction(prob=confidence, subject=subject, obj=contained, action=action, start=start, end=end)

	def predict_all(self):
		"""Predict every movement of every object, add contained slides, and fill both prediction lists."""
		for subject, data in self.state_space_objects.items():
			for start, end in data['object_movements']:
				self.predictions_raw.append(self.predict(subject, start, end))

		self.add_contained_slides()
		# formatting and ground-truth ROTATE sampling are intentionally not run here
		self.predictions_raw = sorted(self.predictions_raw, key=lambda x: x.start)  # sort by start time
		self.predictions = [(item.shape, item.action) for item in self.predictions_raw]

if __name__ == '__main__':
	# script entry point: print the result of actions_order_dataset(unique=True)
	print(actions_order_dataset(unique=True))