import json
import glob
import numpy as np
import pickle
from scipy.spatial import distance
from tqdm import tqdm
from pprint import pprint
from collections import defaultdict, Counter
from collections.abc import Iterable
import os

from vision.visionutils import build_label, build_unique_label, MAX_FRAMES, FRAME_SAMPLING_RATE, load_labels, frames_by_object, frames_by_object_vision
# from temporal.predicates import *
from temporal.inference import StateSpaceClassifier, StateSpaceClassifierSoft, DisentangledClassifier, StateSpaceGT, VideoStateSpaceClassifier
from datasets.cater.enums import CLASSES_MAP, ACTION_CLASSES, _BEFORE, _DURING, _AFTER, ROTATE, NO_OP, inv_ord

from temporal.dataset import ActionAggregation
import logging

# module-level logger; basicConfig configures the root logger at INFO level when this module is imported
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class Scene:
	"""Pair a scene id with a scene dictionary that is created on demand."""

	def __init__(self, scene_id):
		self.id = scene_id
		# stays None until init_scene() is called
		self.scene = None

	def init_scene(self):
		"""Create the scene dictionary if needed and reset its location/feature entries."""
		# self.scene starts out as None, subscripting it directly raised a TypeError
		if self.scene is None:
			self.scene = {}
		self.scene['object_loc'] = {}
		self.scene['object_feat'] = {}
		self.scene['image_feat'] = []

# use the ground truth 3d coordinates, but use the predicted image features
class GroundTruthLoader:
	"""Load scene json files, keeping per-object locations and precomputed frame features.

	Loaded scenes are stored in ``self.scenes``, keyed by video file name.
	"""
	# feature cache shared by every instance, filled on the first access to `features`
	# NOTE(review): the cache is never invalidated, so whichever pickle is loaded first
	# (full or sampled) is the one all later loaders see - confirm this is intended
	_features = None

	def __init__(self, data_dir, split='train'):
		self.data_dir = data_dir
		self.split = split
		self.scenes_path = os.path.join(data_dir, 'scenes/')
		self.scenes = {}
		# chosen in load_all, depending on whether a sample limit is given
		self.feature_path = None

	@property
	def features(self):
		"""Flat ``{file_name: frames}`` feature dictionary, cached on the class after the first load."""
		# we actually want a global class features object, that can be accessed by any instantiation
		if GroundTruthLoader._features is None:
			logger.info('loading features')
			# context manager, so the handle is closed even when unpickling fails
			with open(self.feature_path, 'rb') as feature_file:
				features_by_split = pickle.load(feature_file)
			# aggregate the files into just one flat dictionary, not by train, val etc.
			aggregated = {}
			for split_features in features_by_split.values():
				aggregated.update(split_features)
			GroundTruthLoader._features = aggregated
			logger.info(f'loaded {len(GroundTruthLoader._features)} feature files')
		return GroundTruthLoader._features

	def load_all(self, samples=None, folder='actions_present'):
		"""Load the first `samples` files of the split (all of them when None), then run check_scenes."""
		if samples is None:
			# this file is quite large, only load if we do a full evaluation 5500 samples
			feature_file = 'frame_features_all.pkl'
		else:
			# 1000 total samples
			feature_file = 'frame_features.pkl'
		self.feature_path = os.path.join(self.data_dir, feature_file)
		file_names, _ = load_labels(self.data_dir, self.split, folder=folder)
		logger.info(f'loading {self.split} files in loader')
		for file_name in tqdm(file_names[:samples]):
			self.load_file(file_name)
		self.check_scenes()

	def _count_parsed_items(self):
		"""Count, per data field, the scenes whose first object holds non-empty data for that field.

		Raises ValueError for a field that is neither a dict nor an Iterable.
		"""
		parsed_items = defaultdict(int)
		for scene in self.scenes.values():
			objects = scene['objects']
			if not objects:
				# nothing to sample from; indexing the first object used to raise IndexError here
				continue
			sample_data = next(iter(objects.values()))
			for item, values in sample_data.items():
				# dicts are Iterable as well, so they must be tested first to be reachable
				if isinstance(values, dict):
					first_value = next(iter(values.values()), ())
					parsed_items[item] += bool(len(first_value))
				elif isinstance(values, Iterable):
					parsed_items[item] += bool(len(values))
				else:
					raise ValueError(f'found scene with {item} of type {type(values)}')
		return parsed_items

	# see if some data for each variable loaded for the first object in the scene
	def check_scenes(self):
		"""Log, per data field, how many scenes have data for it."""
		parsed_items = self._count_parsed_items()
		logger.info(f'counts for items {parsed_items}')

	def load_file(self, file_name):
		"""Parse the scene json of `file_name` and store its per-object data in self.scenes."""
		file_path = os.path.join(self.scenes_path, file_name.replace('.avi', '.json'))
		with open(file_path) as scene_file:
			scene = json.load(scene_file)
		save_scene = {'file_path': file_path, 'objects': {}}

		# in the perception version, we will use the detected objects
		for obj in scene['objects']:
			# build_unique_label receives the objects stored so far, presumably to
			# disambiguate duplicate objects - verify against vision.visionutils
			obj_label = build_unique_label(obj, save_scene['objects'])
			save_scene['objects'][obj_label] = {
				'object_loc': self.get_locations(obj),
				'object_feat': [],
				'image_feat': [],
			}

		self.get_features(file_name, scene, save_scene)
		self.post_process(save_scene)
		self.scenes[file_name] = save_scene

	def location_array(self, locations):
		"""Stack the locations of frames 0..MAX_FRAMES-1 (keyed by str(frame number)) into one array."""
		return np.array([locations[str(frame_num)] for frame_num in range(MAX_FRAMES)])

	def get_features(self, file_name, scene, save_scene):
		"""Append the stored roi/image features of `file_name` to each object of save_scene, in place.

		Every stored frame is repeated FRAME_SAMPLING_RATE times and both lists are capped at
		MAX_FRAMES entries. Nothing happens when the file has no stored features.
		"""
		if file_name not in self.features:
			return
		object_features = frames_by_object(self.features[file_name])

		for obj_name, data in save_scene['objects'].items():
			if obj_name not in object_features:
				continue
			# extended in place, the lists must stay the ones referenced by save_scene
			roi_features = data['object_feat']
			image_features = data['image_feat']
			for frame_feature in object_features[obj_name]:
				object_feat = frame_feature.get('roi_features', None)
				image_feat = frame_feature.get('image_feature', None)
				for _ in range(FRAME_SAMPLING_RATE):
					# the cap avoids duplicating the last sampled frame past the video length
					if len(roi_features) < MAX_FRAMES:
						roi_features.append(object_feat)
					if len(image_features) < MAX_FRAMES:
						image_features.append(image_feat)

	def get_locations(self, obj):
		"""Return the location array of one scene object; subclasses may override this."""
		return self.location_array(obj['locations'])

	def post_process(self, save_scene):
		"""Hook called after the features are attached; a no-op in this class."""
		return

# load the ground truth movement data to test our methods
class GroundTruthMovementLoader:
	"""Load the raw scene json of the files in a split, keyed by video file name."""

	def __init__(self, data_dir, split='train'):
		self.data_dir = data_dir
		self.split = split
		self.scenes_path = os.path.join(data_dir, 'scenes/')
		self.scenes = {}

	def load_all(self, samples=None, folder='actions_present'):
		"""Read the scene json of the first `samples` files (all of them when None) into self.scenes."""
		file_names, _ = load_labels(self.data_dir, self.split, folder=folder)
		for file_name in tqdm(file_names[:samples]):
			file_path = os.path.join(self.scenes_path, file_name.replace('.avi', '.json'))
			# close each handle right away instead of leaking one per scene
			with open(file_path) as scene_file:
				self.scenes[file_name] = json.load(scene_file)

# load the previously predicted atomic actions
class VisionLoader(GroundTruthLoader):
	"""Load per-object bounding boxes and predicted atomic actions from the structured scene files."""

	def __init__(self, data_dir, split='train'):
		super().__init__(data_dir, split=split)
		# objects and their corresponding events per frame
		self.detections_path = os.path.join(data_dir, 'structured_scene/')

	def load_all(self, samples=None, folder='actions_present'):
		"""Load the detections of the first `samples` files (all of them when None), then run check_scenes."""
		file_names, _ = load_labels(self.data_dir, self.split, folder=folder)
		logger.info(f'loading {self.split} files in loader')

		# load process and save all video detections
		for file_name in tqdm(file_names[:samples]):
			self.load_file(file_name)
		self.check_scenes()

	def ins_name_to_attribute(self, file_name, detections):
		"""Re-key `detections` in place from instance names to the labels given by build_unique_label."""
		file_path = os.path.join(self.scenes_path, file_name.replace('.avi', '.json'))
		with open(file_path) as scene_file:
			scene = json.load(scene_file)

		object_names = []
		for obj in scene['objects']:
			ins_name = obj['instance']
			obj_name = build_unique_label(obj, object_names)
			object_names.append(obj_name)
			if ins_name in detections:
				# pop before assigning, so an entry whose label equals its instance name is kept
				detections[obj_name] = detections.pop(ins_name)

	def interpolate_boxes(self, data):
		"""Linearly interpolate the missing (all zero) rows of data['bboxes'] in place.

		Rows before the first or after the last detected box take the nearest detected box.
		Nothing changes when there are no rows or no detected box at all.
		"""
		bboxes = data['bboxes']
		if len(bboxes) == 0:
			return
		detected = np.any(bboxes, axis=1)
		if detected.all() or not detected.any():
			return
		frames = np.arange(len(bboxes))
		known = frames[detected]
		missing = frames[~detected]
		# interpolating on the frame index spreads a gap evenly between its two neighbours;
		# the previous hand-rolled stepping reached the end box one frame early, skipped the
		# final frame and assumed exactly MAX_FRAMES rows
		for coord in range(bboxes.shape[1]):
			bboxes[missing, coord] = np.interp(missing, known, bboxes[known, coord])

	def last_known_boxes(self, data):
		"""Replace every missing (all zero) row of data['bboxes'] with the last detected box, in place.

		Rows before the first detection take the first detected box, which matches the way
		interpolate_boxes fills a leading gap. Nothing changes when no box was detected.
		"""
		bboxes = data['bboxes']
		last_known = next((box for box in bboxes if np.any(box)), None)
		if last_known is None:
			return
		for index, box in enumerate(bboxes):
			if np.any(box):
				last_known = box
			else:
				# index - 1 used to wrap around to the last frame when frame 0 was missing
				bboxes[index] = last_known

	def interpolate_missing_data(self, detections, interpolate=True):
		"""Fill missing actions with NO_OP and missing boxes by interpolation or by the last known box."""
		for data in detections.values():
			# missing actions will just be no_op
			data['pred_acts'] = [NO_OP if act is None else act for act in data['pred_acts']]

			# missing boxes become zero rows, which both fill methods treat as "not detected"
			data['bboxes'] = np.array([[0, 0, 0, 0] if bbox is None else bbox for bbox in data['bboxes']], dtype=np.float32)
			if interpolate:
				self.interpolate_boxes(data)
			else:
				self.last_known_boxes(data)

	def load_file(self, file_name):
		"""Load the detections of `file_name` into self.scenes; files without detections are skipped."""
		file_path = os.path.join(self.detections_path, file_name.replace('.avi', '.json'))
		if not os.path.exists(file_path):
			return
		# processed detections = {'ins_names', 'bboxes', 'pred_acts', 'gt_acts'}
		with open(file_path) as detections_file:
			detections = json.load(detections_file)

		# group detections, and event predictions per object
		detections = frames_by_object_vision(detections)

		# convert split object names into attributes for ease of use
		self.ins_name_to_attribute(file_name, detections)

		# fill in the missing detections if necessary
		# for missing bounding boxes use the last known location or interpolate
		self.interpolate_missing_data(detections, interpolate=False)

		save_scene = {'file_path': file_path, 'objects': {}}

		# in the perception version, we will use the detected objects
		for obj_label, data in detections.items():
			save_scene['objects'][obj_label] = {
				'object_loc': data['bboxes'],
				'pred_acts': data['pred_acts'],
				'gt_acts': data['gt_acts'],
			}

		self.scenes[file_name] = save_scene

class PerturbedLoader(GroundTruthLoader):
	"""Ground truth loader that adds gaussian noise to the object locations."""

	def __init__(self, data_dir, split='train', mean=0.0, std=0.01):
		super().__init__(data_dir, split=split)
		# parameters of the noise drawn in get_locations
		self.mean, self.std = mean, std

	def get_locations(self, obj):
		"""Return the object's location array with np.random.normal(mean, std) noise added."""
		#TODO: prevent location from updating, when the object is not detected
		clean = np.array(self.location_array(obj['locations']))
		noise = np.random.normal(self.mean, self.std, size=clean.shape)
		return clean + noise

	def post_process(self, save_scene):
		"""Hold each object's location while its roi feature is None.

		The held value is the location of the most recent frame that has a feature,
		starting from the location of frame 0.
		"""
		for data in save_scene['objects'].values():
			locations = data['object_loc']
			held_location = locations[0]
			for frame, roi_feature in enumerate(data['object_feat']):
				if roi_feature is not None:
					held_location = locations[frame]
					continue
				# could not detect it, possibly contained, or occluded
				locations[frame] = held_location

# TODO: decouple the StateSpaces for each file, make another class StateSpaceHandler that contains all state spaces
class StateSpace:
	"""Derive per-object movement intervals from locations and run the action classifiers on them."""

	def __init__(self, loader, distance_change=1e-2, smooth=True):
		self.loader = loader
		# same dictionary as the loader's, so the methods below annotate the loader's scenes in place
		self.scenes = self.loader.scenes
		# minimum frame-to-frame distance that counts as movement
		self.distance_change = distance_change
		self.smooth = smooth
		self.detect_movements()

	def detect_movements(self):
		"""Store the (start, end) movement intervals of every object under data['object_movements']."""
		logger.info('parsing object movements')
		for file_name, scene in tqdm(self.scenes.items()):
			for obj_name, data in scene['objects'].items():
				movements = self.movements_by_distance(data)
				movements = self.smooth_movements(movements)
				movements = self.group_movements(movements)
				data['object_movements'] = movements

	def movements_by_distance(self, data):
		"""Return a flat list [start, end, start, end, ...] of the times at which movement toggles.

		A step counts as movement when consecutive (optionally smoothed) locations are more than
		self.distance_change apart. A movement still open at the end is closed at MAX_FRAMES - 1.
		"""
		locations = data['object_loc']
		movements = []

		if self.smooth:
			# moving average over a window of up to smooth_range frames on each side
			smooth_range = 4
			smooth_locations = np.array([
				np.mean(locations[max(0, i - smooth_range): min(len(locations), i + smooth_range + 1)], axis=0)
				for i in range(len(locations))
			])
		else:
			smooth_locations = locations

		prev_location = smooth_locations[0]
		moving = False
		# time indexes smooth_locations[1:], so it is one less than the frame index of `location`
		for time, location in enumerate(smooth_locations[1:]):
			now_moving = bool(distance.euclidean(location, prev_location) > self.distance_change)
			if now_moving != moving:
				# record every switch between resting and moving
				movements.append(time)
				moving = now_moving
			prev_location = location
		if len(movements) % 2 != 0:
			# still moving till the end of the video, add final time index
			movements.append(MAX_FRAMES - 1)
		return movements

	def smooth_movements(self, movements, time_window=1):
		"""Drop consecutive boundary pairs that are at most `time_window` apart; returns an np.array.

		An even number of boundaries is kept, a trailing unpaired one is dropped.
		"""
		# smooth the movements when the movement windows have a pause in the location changes
		# ie, if locations are [4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 11]
		# then we record two movements from time [1, 5, 6, 8] even though the movement may be from [1, 8]
		m = np.array(movements)
		# first index of each close pair, ie the index of 5 above
		first_idx = np.flatnonzero(np.diff(m) <= time_window)
		# delete it together with its partner, ie 6 above
		m = np.delete(m, np.concatenate([first_idx, first_idx + 1]))
		if len(m) % 2 != 0:
			m = m[:-1]
		assert len(m) % 2 == 0, (len(movements), len(m), movements, m)
		return m

	def group_movements(self, movements):
		"""Reshape the flat boundary array into rows of (start, end)."""
		return movements.reshape((-1, 2))

	def detect_feature_changes(self):
		#TODO: need to find a way to detect movement given a static location, maybe try some of the optical flow methods
		raise NotImplementedError

	def get_classifier(self, state_space, soft=False):
		"""Return the soft or the hard state space classifier for one scene."""
		if soft:
			return StateSpaceClassifierSoft(state_space)
		return StateSpaceClassifier(state_space)

	def infer_classes(self, acceptance_rate=0.7, soft=False, top_k=1):
		"""Run the classifier on every scene and store its predictions and composite actions."""
		logger.info('predicting actions')
		for file_name, state_space in tqdm(self.scenes.items()):
			model = self.get_classifier(state_space, soft=soft)
			if soft:
				# top_k is only passed to the soft classifier
				model.predict_all(acceptance_rate=acceptance_rate, top_k=top_k)
			else:
				model.predict_all(acceptance_rate=acceptance_rate)
			state_space['predictions'] = model.map_to_classes()
			state_space['predictions_raw'] = model.predictions_raw
			state_space['composite_actions_raw'] = model.composite_actions_raw()
			state_space['composite_actions'] = model.composite_actions(state_space['composite_actions_raw'])

	def view_gt(self):
		"""Attach the StateSpaceGT predictions to every object as data['gt_predictions_raw']."""
		logger.info('adding ground truth actions')
		for file_name, state_space in tqdm(self.scenes.items()):
			file_path = os.path.join(self.loader.scenes_path, file_name.replace('.avi', '.json'))
			with open(file_path) as scene_file:
				gt_json = json.load(scene_file)
			gt_predictions = StateSpaceGT(gt_json)
			gt_predictions.predict_all()
			for obj_name, data in state_space['objects'].items():
				if obj_name in gt_predictions.prediction_by_object:
					data['gt_predictions_raw'] = gt_predictions.prediction_by_object[obj_name]

	def view_errors(self):
		"""Log how often each action is predicted without being in the ground truth, and vice versa."""
		overestimations = Counter()
		underestimations = Counter()

		for video, data in self.scenes.items():
			for obj_name, obj_data in data['objects'].items():
				predictions = set()
				gt = set()
				if 'gt_predictions_raw' in obj_data and 'movement_predictions' in obj_data:
					gt.update(action[0] for action in obj_data['gt_predictions_raw'])
					predictions.update(action[0] for action in obj_data['movement_predictions'])
				overestimations.update(predictions - gt)
				underestimations.update(gt - predictions)

		logger.info('false positives')
		logger.info(overestimations.most_common())
		# ground truth actions that were not predicted are false negatives, not true negatives
		logger.info('false negatives')
		logger.info(underestimations.most_common())

	def pop_movement(self, action, movements):
		"""Match `action` to the movement with the largest temporal overlap and remove that movement.

		`action` and every entry of `movements` are (atomic action, start, end) triples.
		Returns (aa, best_aa, best_intersection, best_union). When nothing overlaps, best_aa is
		None, intersection and union are 0 and `movements` is left untouched.
		"""
		best_index = None
		best_intersection = 0
		best_union = 0
		best_aa = None

		aa, start, end = action
		for index, (comp_aa, comp_start, comp_end) in enumerate(movements):
			if comp_start > end or start > comp_end:
				continue
			intersection = min(end, comp_end) - max(start, comp_start)
			if intersection > best_intersection:
				best_intersection = intersection
				best_union = max(end, comp_end) - min(start, comp_start)
				best_index = index
				best_aa = comp_aa
		# match without replacement, but only remove a movement that was actually matched;
		# deleting index 0 unconditionally threw away an unrelated candidate
		if best_index is not None:
			del movements[best_index]
		return aa, best_aa, best_intersection, best_union

	def view_iou_accuracy(self):
		"""Print the temporal IoU and the per-class recall of the movement predictions."""
		intersection = 0
		union = 0
		class_predictions = defaultdict(list)
		for video, data in self.scenes.items():
			for obj_name, obj_data in data['objects'].items():
				if 'gt_predictions_raw' not in obj_data or 'movement_predictions' not in obj_data:
					continue
				# copy, as pop_movement removes the matched movements
				movements = obj_data['movement_predictions'].copy()
				for action in obj_data['gt_predictions_raw']:
					if len(movements) and action[0][1] != NO_OP:
						aa, best_aa, best_intersection, best_union = self.pop_movement(action, movements)
						intersection += best_intersection
						union += best_union
						class_predictions[aa].append(aa == best_aa)
		# without a single match the union is 0 and the IoU is undefined
		iou = intersection / union if union else float('nan')
		print(f'Temporal IoU: {iou}')
		class_scores = {aa: np.mean(preds) for aa, preds in class_predictions.items()}
		pprint(class_scores)
		print(f'Average class recall: {np.mean(list(class_scores.values()))}')

class VideoStateSpace(StateSpace):
	"""State space whose event intervals come from the per-frame predicted actions."""

	def __init__(self, loader):
		# StateSpace.__init__ is not called, so distance_change and smooth are not set here
		self.loader = loader
		self.scenes = self.loader.scenes
		# get the event intervals, useful for later
		self.detect_movements()

	def detect_movements(self):
		"""Store the filtered (start, end) event intervals of every object under data['object_movements']."""
		logger.info('parsing object movements')
		for file_name, scene in tqdm(self.scenes.items()):
			for obj_name, data in scene['objects'].items():
				movements = self.get_event_intervals(data, end_buffer=3)
				movements = self.filter_event_intervals(movements, min_window=3)
				data['object_movements'] = movements

	def get_event_intervals(self, data, end_buffer=3):
		"""Return a flat np.array [start, end, start, end, ...] of the event boundaries in data['pred_acts'].

		An event starts at the first frame that is not NO_OP and ends at the first NO_OP frame that
		has no other action within the next `end_buffer` frames. An event still open at the end of
		the sequence is closed at the last index.
		"""
		actions = data['pred_acts']

		# collect start and end times for each interval event, with end_buffer buffer
		# for missed detections
		movements = []
		moving = False

		time = 0
		# last valid index of this sequence; MAX_FRAMES - 1 was hard-coded here, which indexes
		# out of range (or looks at the wrong frame) when the sequence has another length
		last_index = len(actions) - 1
		while time < len(actions):
			if actions[time] != NO_OP:
				# start recording a event sequence with the start time
				if not moving:
					movements.append(time)
					moving = True
				time += 1
				continue
			if not moving:
				time += 1
				continue
			# in the middle of an event sequence:
			# if there is an event within the next end_buffer frames, jump to it
			for step in range(1, end_buffer + 1):
				if actions[min(time + step, last_index)] != NO_OP:
					time += step
					break
			# else end the sequence now
			else:
				movements.append(time)
				moving = False
				time += 1

		# if event goes till the end of the video, then end it
		if len(movements) % 2 != 0:
			movements.append(last_index)
		return np.array(movements)

	def filter_event_intervals(self, movements, min_window=3):
		"""Reshape the boundaries to rows of (start, end) and keep the events longer than min_window."""
		movements = movements.reshape((-1, 2))  # (events, start/end times)
		lengths = movements[:, 1] - movements[:, 0]  # event durations
		return movements[lengths > min_window]

	def get_classifier(self, state_space):
		"""Return the video state space classifier for one scene."""
		return VideoStateSpaceClassifier(state_space)

	def infer_classes(self, composite=True):
		"""Run the classifier on every scene; composite actions are only stored when `composite` is set."""
		logger.info('predicting actions')
		for file_name, state_space in tqdm(self.scenes.items()):
			model = self.get_classifier(state_space)
			model.predict_all()
			state_space['predictions'] = model.map_to_classes()  # class integer labels
			state_space['predictions_raw'] = model.predictions_raw  # class string labels
			if composite:
				state_space['composite_actions_raw'] = model.composite_actions_raw()  # enumerate composite events
				state_space['composite_actions'] = model.composite_actions(state_space['composite_actions_raw'])
	
class VideoStateSpaceGT(VideoStateSpace):
	"""Video state space that classifies with StateSpaceGT and does not detect movements on construction."""

	def __init__(self, loader):
		# no parent constructor call, so detect_movements is not run
		self.loader = loader
		self.scenes = loader.scenes

	def get_classifier(self, state_space):
		"""Return the StateSpaceGT model for one scene."""
		return StateSpaceGT(state_space)

	def infer_classes(self, composite=True):
		"""Store the model predictions on every scene; `composite` is accepted but not used."""
		logger.info('predicting actions')
		# occurrences of each (shape, action) pair, only kept locally
		atomic_events = Counter()
		for state_space in tqdm(self.scenes.values()):
			model = self.get_classifier(state_space)
			model.predict_all()
			# class integer labels first, then the class string labels
			state_space['predictions'] = model.map_to_classes()
			raw_predictions = model.predictions_raw
			state_space['predictions_raw'] = raw_predictions

			atomic_events.update((prediction.shape, prediction.action) for prediction in raw_predictions)

# MAP estimate
class DisentangledTraining:
	"""Map composite actions (action, relation, action) to labels from aggregated action features."""

	def __init__(self, state_space, labels):
		aa = ActionAggregation(state_space, labels)
		aa.aggregate()
		self.action_features = aa.action_features
		# composite action -> list of (value, label), reduced to one label by filter_comp_action
		self.composite_actions = defaultdict(list)

	def get_relation(self, time, time_2):
		"""Return _AFTER, _BEFORE or _DURING for `time` relative to `time_2`."""
		if time > time_2:
			return _AFTER
		if time < time_2:
			return _BEFORE
		return _DURING

	def add_comp_action(self, action_cands, label):
		"""Record the composite actions formed by consecutive candidate pairs for `label`.

		At most max_search_iter pairs are used. Each pair is stored in both directions together
		with the sum of the two candidates' aggregated values.
		"""
		max_search_iter = 10
		agg_action_feature = self.action_features[label]
		# every trial also reads the next candidate, so stop one before the end instead of
		# raising IndexError when there are fewer than max_search_iter + 1 candidates
		for trial in range(min(max_search_iter, len(action_cands) - 1)):
			action, time = action_cands[trial]
			action_value = agg_action_feature[action, time]
			action_2, time_2 = action_cands[trial + 1]
			action_2_value = agg_action_feature[action_2, time_2]

			relation = self.get_relation(time, time_2)
			action = ACTION_CLASSES[action]
			action_2 = ACTION_CLASSES[action_2]
			composite_action = (action, relation, action_2)
			composite_action_reverse = (action_2, inv_ord[relation], action)
			composite_value = action_value + action_2_value

			self.composite_actions[composite_action].append((composite_value, label))
			self.composite_actions[composite_action_reverse].append((composite_value, label))

	def filter_comp_action(self):
		"""Replace each composite action's (value, label) list with the label of its first sorted entry."""
		for ca in self.composite_actions:
			# NOTE(review): ascending sort, so the label with the smallest value wins -
			# confirm that the largest value was not intended for a MAP estimate
			val, label = sorted(self.composite_actions[ca])[0]
			self.composite_actions[ca] = label

	def map_labels(self, window_size):
		"""Aggregate every label's action features into time windows and derive the composite actions."""
		for label, features in self.action_features.items():
			# want an even dimension for the reshape, 301 -> 300 time
			action_feature = np.sum(features, axis=0)[:, :-1]

			# accumulate the occurrences within each consecutive window for a better estimate of the
			# temporal relations
			agg_action_feature = np.sum(action_feature.reshape(len(ACTION_CLASSES), -1, window_size), axis=-1)
			self.action_features[label] = agg_action_feature

			# find individual actions and counts over time, largest count first
			order = np.argsort(-agg_action_feature.flatten())
			# unflatten the index and zip the axes into (action, time) pairs
			action_cands = np.dstack(np.unravel_index(order, agg_action_feature.shape))[0]

			self.add_comp_action(action_cands, label)
		self.filter_comp_action()

		logger.info(f'{len(self.composite_actions)} composite actions detected')

class StateSpaceDisentangled(StateSpace):
	"""State space that maps each composite action to the label it co-occurs with most often."""

	def __init__(self, loader, labels, distance_change=1e-2):
		super().__init__(loader, distance_change=distance_change)
		# labels is a (file names, labels per file) pair
		self.file_names, self.labels = labels
		# first pass with the plain classifier, its composite actions feed map_labels
		self.disentangle = False
		self.infer_classes()
		self.composite_actions = defaultdict(Counter)
		self.map_labels()

		# from now on get_classifier returns the disentangled classifier
		self.disentangle = True

	def get_classifier(self, state_space, soft=False):
		"""Return the disentangled classifier once the labels are mapped, else the parent's classifier.

		`soft` is accepted because the inherited infer_classes always passes it; without it the
		constructor failed with a TypeError. It is ignored in disentangled mode.
		"""
		if self.disentangle:
			return DisentangledClassifier(state_space, self.composite_actions)
		return super().get_classifier(state_space, soft=soft)

	def map_labels(self):
		"""Count the labels seen with every composite action and keep the most common one."""
		for file_name, labels in zip(self.file_names, self.labels):
			if file_name in self.scenes:
				scene = self.scenes[file_name]
				for comp_action in scene['composite_actions_raw']:
					self.composite_actions[comp_action].update(labels)
		# only values of existing keys are replaced, so iterating while assigning is safe
		for comp_action, counts in self.composite_actions.items():
			map_label = counts.most_common(1)[0][0]
			self.composite_actions[comp_action] = map_label

