import torch
from torch.utils.data import Dataset
from collections import defaultdict

from temporal.inference import ACTION_CLASSES, COMP_ACTIONS, StateSpaceGT, ORDERING
from datasets.cater.enums import MAX_FRAMES, SHAPES
from datasets.cater.generate import MAX_SHAPE

import logging
import numpy as np

# module-wide logging: root configured at INFO on import, plus a named logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# vocabulary sizes used to size the feature / label tensors below
INPUT_DIM = len(ACTION_CLASSES)  # number of atomic action classes
OUTPUT_DIM = len(frozenset(COMP_ACTIONS.values()))  # number of distinct COMP_ACTIONS values
REL_DIM = len(ORDERING)  # number of entries in ORDERING
N_TIMELINES = 30  # max number of objects we will track

class TemporalActions(Dataset):
	"""Predicted atomic actions -> composite actions, both as multi-hot vectors (14 -> 301).

	Each sample pairs a multi-hot vector of size ``self.input_dim`` built from
	``state_space.scenes[file][self.state_space_key]`` with a multi-hot vector
	of size OUTPUT_DIM built from the sample's composite action label indices.

	Args:
		state_space: object exposing ``scenes``, a mapping from file name to a
			per-scene record indexed by ``state_space_key``.
		labels: ``(files, actions)`` pair; ``actions[i]`` holds the label
			indices of ``files[i]``.
	"""

	def __init__(self, state_space, labels):
		self.files, self.actions = labels
		self.state_space = state_space
		# size of the input vector and the scene entry it is built from
		self.input_dim = INPUT_DIM
		self.state_space_key = 'predictions'

	def __len__(self):
		return len(self.files)

	def __getitem__(self, idx):
		def multi_hot(indices, size):
			# zero vector of length ``size`` with a 1 at every index in ``indices``
			vec = torch.zeros(size)
			for i in indices:
				vec[i] = 1
			return vec

		scenes = self.state_space.scenes
		name = self.files[idx]
		if name not in scenes:
			# only part of the data may be loaded (debugging): use the first scene
			name = list(scenes)[0]
		predicted = list(scenes[name][self.state_space_key])
		targets = list(self.actions[idx])
		return multi_hot(predicted, self.input_dim), multi_hot(targets, OUTPUT_DIM)

class TemporalActionsGT(Dataset):
	"""Ground-truth atomic actions -> composite actions, as multi-hot vectors (14 -> 301).

	Args:
		inputs: ``(input_files, input_actions)`` pair; ``input_actions[i]``
			holds the atomic action class indices of ``input_files[i]``.
		labels: ``(files, actions)`` pair; ``actions[i]`` holds the composite
			action label indices of ``files[i]``.
	"""

	def __init__(self, inputs, labels):
		self.input_files, self.input_actions = inputs
		self.files, self.actions = labels
		# file name -> position in input_files, built once so __getitem__ does an
		# O(1) lookup instead of a linear list.index scan per sample (O(n^2) per
		# epoch). setdefault keeps the first occurrence, matching list.index.
		self._input_index = {}
		for position, name in enumerate(self.input_files):
			self._input_index.setdefault(name, position)

	def __len__(self):
		return len(self.files)

	def __getitem__(self, idx):
		"""Return ``(input_tensor, output_tensor)`` multi-hot vectors for sample *idx*.

		Raises:
			ValueError: if the sample's file has no entry in ``input_files``
				(the same exception type ``list.index`` raised before).
		"""
		file_name = self.files[idx]
		try:
			position = self._input_index[file_name]
		except KeyError:
			raise ValueError(f'{file_name!r} is not in input_files') from None
		inputs = list(self.input_actions[position])
		outputs = list(self.actions[idx])

		input_tensor = torch.zeros(INPUT_DIM)
		for t1_class in inputs:
			input_tensor[t1_class] = 1
		output_tensor = torch.zeros(OUTPUT_DIM)
		for t2_class in outputs:
			output_tensor[t2_class] = 1

		return input_tensor, output_tensor

class LogicCompositeActions(TemporalActions):
	"""Composite actions evaluated on the predicted state space -> training labels (301 -> 301).

	Enumerates the 301 composite class labels, evaluates them on the state
	space and maps them to the training labels. Reuses
	``TemporalActions.__getitem__`` but reads each scene's
	``'composite_actions'`` entry and encodes it into OUTPUT_DIM slots.
	"""

	def __init__(self, state_space, labels):
		super().__init__(state_space, labels)
		self.input_dim = OUTPUT_DIM
		self.state_space_key = 'composite_actions'

class LogicCompositeActionsSoft(LogicCompositeActions):
	"""Soft variant of LogicCompositeActions.

	Instead of one bit per composite class, each class present in the scene's
	``'composite_actions'`` mapping contributes a 4-value numpy feature, placed
	at slots ``[4 * class, 4 * class + 4)`` of the input vector.
	"""

	def __getitem__(self, idx):
		scenes = self.state_space.scenes
		name = self.files[idx]
		if name not in scenes:
			# only part of the data may be loaded (debugging): use the first scene
			name = list(scenes)[0]
		features = scenes[name][self.state_space_key]
		labels = list(self.actions[idx])

		width = 4  # values per composite class
		input_tensor = torch.zeros(self.input_dim * width)
		for cls, feature in features.items():
			lo = cls * width
			input_tensor[lo: lo + width] = torch.from_numpy(feature)

		output_tensor = torch.zeros(OUTPUT_DIM)
		for cls in labels:
			output_tensor[cls] = 1

		return input_tensor, output_tensor

class LogicCompositeActionsGT(TemporalActions):
	"""Composite actions evaluated on ground-truth data -> training labels (301 -> 301).

	Enumerates the 301 class labels, evaluates them on ground-truth data and
	maps them to the training labels. ``state_space`` is indexed directly by
	file name; each entry lists the composite class indices to set in an input
	vector of size OUTPUT_DIM.
	"""

	def __init__(self, state_space, labels):
		# TemporalActions.__init__ is not called; only the attributes used by
		# this class's own methods are set.
		self.state_space = state_space
		self.files, self.actions = labels
		self.input_dim = OUTPUT_DIM

	def __len__(self):
		return len(self.files)

	def __getitem__(self, idx):
		truth = self.state_space[self.files[idx]]

		input_tensor = torch.zeros(self.input_dim)
		for cls in list(truth):
			input_tensor[cls] = 1

		output_tensor = torch.zeros(OUTPUT_DIM)
		for cls in list(self.actions[idx]):
			output_tensor[cls] = 1

		return input_tensor, output_tensor

def min_max_norm(vector):
	"""Linearly rescale *vector* so its minimum maps to 0 and its maximum to 1.

	A constant or empty vector has no range to normalise by and yields zeros.
	The result is always floating point and on the input's device: float
	inputs keep their dtype, integer inputs use torch's default float dtype
	(the same dtype the division branch produces).
	"""
	out_dtype = vector.dtype if vector.is_floating_point() else torch.get_default_dtype()
	if vector.numel() == 0:
		# torch.min / torch.max raise on empty tensors
		return torch.zeros(vector.size(), dtype=out_dtype, device=vector.device)
	min_v = torch.min(vector)
	range_v = torch.max(vector) - min_v
	if range_v > 0:
		return (vector - min_v) / range_v
	return torch.zeros(vector.size(), dtype=out_dtype, device=vector.device)

def bit_features_sum(inputs):
	"""Encode ``(atomic_action, start, end)`` tuples as bits plus a temporal summary.

	Returns the concatenation of:
	  * a flattened ``(INPUT_DIM * MAX_FRAMES,)`` binary map made of one
	    MAX_FRAMES-long segment per entry of ACTION_CLASSES, with frames
	    ``[start, end)`` set to 1 for each known action, and
	  * a length ``len(ACTION_CLASSES)`` vector holding, per action, the sum of
	    normalised frame positions ``i / MAX_FRAMES`` over its spans,
	    min-max normalised.

	Actions not in ACTION_CLASSES are ignored. Frame bounds are clipped to
	``[0, MAX_FRAMES]`` so a span cannot write into a neighbouring action's
	segment. Clipping also stops a negative ``end`` from summing positions
	through Python's negative-index slicing.
	"""
	input_tensor = torch.zeros(INPUT_DIM * MAX_FRAMES)
	sum_tensor = torch.zeros(len(ACTION_CLASSES))
	# i / MAX_FRAMES for every frame i. The integer arange has exactly
	# MAX_FRAMES entries, unlike a float step of 1/MAX_FRAMES, whose rounding
	# can add an extra element.
	temporal_positions = torch.arange(MAX_FRAMES) / MAX_FRAMES
	for atomic_action, start, end in inputs:
		if atomic_action not in ACTION_CLASSES:
			continue
		index = ACTION_CLASSES.index(atomic_action)
		start = min(max(start, 0), MAX_FRAMES)
		end = min(max(end, 0), MAX_FRAMES)
		offset = index * MAX_FRAMES
		input_tensor[offset + start: offset + end] = 1
		sum_tensor[index] += torch.sum(temporal_positions[start: end])
	return torch.cat([input_tensor, min_max_norm(sum_tensor)])

def bit_features(inputs):
	"""Encode predictions as a flattened binary action x time map.

	Returns a ``(INPUT_DIM * MAX_FRAMES,)`` tensor made of one MAX_FRAMES-long
	segment per entry of ACTION_CLASSES. For each prediction whose
	``(shape, action)`` pair is a known atomic action class, frames
	``[start, end)`` of that class's segment are set to 1. Unknown pairs are
	ignored.

	Frame bounds are clipped to ``[0, MAX_FRAMES]`` so an out-of-range
	prediction cannot write into a neighbouring class's segment.
	"""
	input_tensor = torch.zeros(INPUT_DIM * MAX_FRAMES)
	for prediction_raw in inputs:
		atomic_action = (prediction_raw.shape, prediction_raw.action)
		if atomic_action not in ACTION_CLASSES:
			continue
		offset = ACTION_CLASSES.index(atomic_action) * MAX_FRAMES
		start = min(max(prediction_raw.start, 0), MAX_FRAMES)
		end = min(max(prediction_raw.end, 0), MAX_FRAMES)
		input_tensor[offset + start: offset + end] = 1
	return input_tensor

def bit_features_soft(inputs, sampling_rate):
	"""Like ``bit_features``, but spans are filled with ``prob`` and the result is strided.

	For each prediction whose ``(shape, action)`` pair is a known atomic action
	class, frames ``[start, end)`` of that class's MAX_FRAMES-long segment are
	set to ``prediction.prob``. A later overlapping prediction overwrites an
	earlier one. Frame bounds are clipped to ``[0, MAX_FRAMES]`` so a span
	cannot write into a neighbouring class's segment. The flattened vector is
	then subsampled with stride *sampling_rate*.

	NOTE(review): the stride is applied to the flattened
	``(INPUT_DIM * MAX_FRAMES,)`` vector, not per action segment. Unless
	MAX_FRAMES is divisible by *sampling_rate*, later segments are sampled at
	shifted frame phases, and the output length is
	``ceil(INPUT_DIM * MAX_FRAMES / sampling_rate)``. This is left unchanged
	because per-segment subsampling would change the feature length that
	consumers see. Confirm which behaviour is intended.
	"""
	input_tensor = torch.zeros(INPUT_DIM * MAX_FRAMES)
	for prediction_raw in inputs:
		atomic_action = (prediction_raw.shape, prediction_raw.action)
		if atomic_action not in ACTION_CLASSES:
			continue
		offset = ACTION_CLASSES.index(atomic_action) * MAX_FRAMES
		start = min(max(prediction_raw.start, 0), MAX_FRAMES)
		end = min(max(prediction_raw.end, 0), MAX_FRAMES)
		input_tensor[offset + start: offset + end] = prediction_raw.prob
	return input_tensor[::sampling_rate]

def bit_features_timelines(inputs):
	"""Encode predictions into one action x time map per subject.

	Returns a ``(N_TIMELINES, INPUT_DIM, MAX_FRAMES)`` tensor. Each distinct
	``prediction.subject`` is given the next free timeline index in order of
	first appearance. This happens even if none of its actions are known
	classes. For known ``(shape, action)`` classes, frames ``[start, end)`` of
	that subject's class row are set to ``prediction.prob``.
	"""
	timelines = torch.zeros(N_TIMELINES, INPUT_DIM, MAX_FRAMES)
	slot_of = {}  # subject -> timeline index
	for pred in inputs:
		slot = slot_of.setdefault(pred.subject, len(slot_of))
		key = (pred.shape, pred.action)
		if key not in ACTION_CLASSES:
			continue
		timelines[slot, ACTION_CLASSES.index(key), pred.start: pred.end] = pred.prob
	return timelines

def bit_features2d(inputs):
	"""Encode predictions as an ``(INPUT_DIM, MAX_FRAMES)`` action x time map.

	Row ``ACTION_CLASSES.index((shape, action))`` has frames ``[start, end)``
	set to ``prediction.prob``. Predictions with unknown ``(shape, action)``
	pairs are ignored, and later overlapping predictions overwrite earlier ones.
	"""
	grid = torch.zeros(INPUT_DIM, MAX_FRAMES)
	for pred in inputs:
		key = (pred.shape, pred.action)
		if key not in ACTION_CLASSES:
			continue
		row = ACTION_CLASSES.index(key)
		grid[row, pred.start: pred.end] = pred.prob
	return grid

class TemporalFeatures(Dataset):
	"""Action x time feature maps of predicted atomic actions -> composite-action labels.

	Maps action x time features to composite actions using predicted state
	spaces (14 * 301 -> 301). Each scene's raw predictions
	(``state_space.scenes[file]['predictions_raw']``) are rendered by one of
	the ``bit_features*`` encoders, selected by ``mode``. The labels become a
	multi-hot vector of length ``output_dim``.

	Args:
		state_space: object exposing ``scenes``, a mapping from file name to a
			per-scene record indexed by ``state_space_key``.
		labels: ``(files, actions)`` pair; ``actions[i]`` holds the label
			indices of ``files[i]``.
		mode: '1d', '2d', 'soft' or 'timelines' (see ``get_input_tensor``).
		sampling_rate: stride used only by the 'soft' encoder.
		output_dim: length of the multi-hot label vector.
	"""

	# Set the first time a requested scene is missing from the loaded data, so
	# the fallback warning in __getitem__ is logged only once per process.
	has_warned = False

	def __init__(self, state_space, labels, mode='1d', sampling_rate=1, output_dim=OUTPUT_DIM):
		self.state_space = state_space
		self.files, self.actions = labels
		self.state_space_key = 'predictions_raw'
		self.mode = mode
		self.sampling_rate = sampling_rate
		self.output_dim = output_dim

	def __len__(self):
		return len(self.files)

	def get_inputs(self, file_name):
		"""Return the raw predictions stored for *file_name*, as a list."""
		return list(self.state_space.scenes[file_name][self.state_space_key])

	def get_input_tensor(self, inputs):
		"""Encode *inputs* with the encoder selected by ``self.mode``.

		Raises:
			ValueError: if ``self.mode`` is not a known feature mode.
		"""
		if self.mode == '1d':
			input_tensor = bit_features(inputs)
		elif self.mode == '2d':
			input_tensor = bit_features2d(inputs)
		elif self.mode == 'soft':
			input_tensor = bit_features_soft(inputs, self.sampling_rate)
		elif self.mode == 'timelines':
			input_tensor = bit_features_timelines(inputs)
		else:
			raise ValueError(f'unknown feature mode {self.mode}')
		return input_tensor

	def __getitem__(self, idx):
		"""Return ``(input_tensor, output_tensor)`` for sample *idx*."""
		file_name = self.files[idx]
		scenes = self.state_space.scenes

		# Sometimes only partial data is loaded, for debugging purposes. We then
		# fall back to the first loaded scene. Its features get paired with this
		# sample's labels, so log the substitution (once) instead of doing it
		# silently.
		if file_name not in scenes:
			if not TemporalFeatures.has_warned:
				logger.warning('file name %s not in loaded scenes, using the first loaded scene instead', file_name)
				logger.warning('this warning will now be suppressed')
				TemporalFeatures.has_warned = True
			file_name = list(scenes)[0]

		# generate \mathbf{M_T} input feature matrix
		inputs = self.get_inputs(file_name)
		input_tensor = self.get_input_tensor(inputs)

		# output labels as a multi-hot vector
		output_tensor = torch.zeros(self.output_dim)
		for t2_class in list(self.actions[idx]):
			output_tensor[t2_class] = 1

		return input_tensor, output_tensor

class TemporalFeaturesGT(TemporalFeatures):
	"""TemporalFeatures whose inputs are read from ground-truth data.

	NOTE(review): ``get_inputs`` indexes ``state_space`` directly by file name,
	but the inherited ``__getitem__`` still checks membership in
	``state_space.scenes``. The object passed as ``state_space`` therefore
	needs both; confirm this holds for the ground-truth source.
	"""

	def get_inputs(self, file_name):
		"""Return the ground-truth entry stored under *file_name*, without copying it into a list."""
		ground_truth = self.state_space
		return ground_truth[file_name]

class ActionAggregation(TemporalFeatures):
	"""Group the 2-D action x time features of every loaded scene by label.

	``aggregate()`` fills ``action_features[label]`` with one
	``bit_features2d`` matrix (as a numpy array) per loaded scene that carries
	``label``. Because this is a TemporalFeatures in '2d' mode, indexing the
	dataset yields the same features paired with the multi-hot labels.
	``get_inputs`` is inherited from TemporalFeatures.
	"""

	def __init__(self, state_space, labels):
		# Initialise the parent so the inherited __getitem__ finds mode,
		# sampling_rate and output_dim. Without this, indexing raised
		# AttributeError. '2d' matches the features aggregate() collects.
		super().__init__(state_space, labels, mode='2d')
		# label -> list of per-scene feature matrices
		self.action_features = defaultdict(list)

	def aggregate(self):
		"""Append each loaded scene's ``bit_features2d`` matrix under each of its labels.

		Files absent from ``state_space.scenes`` are skipped. Calling this more
		than once appends the features again.
		"""
		for file_name, labels in zip(self.files, self.actions):
			if file_name not in self.state_space.scenes:
				continue
			feature = bit_features2d(self.get_inputs(file_name)).numpy()
			for label in labels:
				self.action_features[label].append(feature)


class GeneratedFeatures(Dataset):
	"""Object timelines of generated samples -> multi-hot composite-action labels.

	Each item of ``timelines`` is an ``(inputs, outputs)`` pair. ``inputs``
	maps a shape to its list of movements ``(shape_id, action, start, end,
	probs)``. ``outputs`` holds the label indices set to 1 in the output vector.

	Args:
		timelines: sequence of ``(inputs, outputs)`` samples.
		output_dim: length of the multi-hot label vector.
		n_timelines: number of per-object timelines allocated in 'agg_after' mode.
		mode: 'agg_before' merges all objects into a single timeline.
			'agg_after' keeps one timeline per ``shape_id`` and leaves the
			object aggregation to the downstream relation modules.
	"""

	def __init__(self, timelines, output_dim=OUTPUT_DIM, n_timelines=len(SHAPES) * MAX_SHAPE, mode='agg_after'):
		self.timelines = timelines
		self.n_timelines = n_timelines
		self.output_dim = output_dim
		self.mode = mode

	def __len__(self):
		return len(self.timelines)

	def format_input(self, ts):
		"""Render one sample's movements into a ``(T, INPUT_DIM, MAX_FRAMES)`` tensor.

		T is 1 in 'agg_before' mode and ``n_timelines`` in 'agg_after' mode.
		Each movement is merged into its ``(timeline, action class)`` row by an
		element-wise max, so overlapping movements (and, in 'agg_before',
		overlapping objects) keep the highest value per frame.

		NOTE(review): ``probs`` is combined with the whole MAX_FRAMES-long row,
		so it is assumed to be a full-length per-frame timeline. Confirm this
		against the generator.

		Raises:
			ValueError: if ``self.mode`` is neither 'agg_before' nor 'agg_after',
				or (from ``ACTION_CLASSES.index``) if a ``(shape, action)`` pair
				is not a known class.
		"""
		if self.mode == 'agg_before':
			# flatten all objects into one timeline up front, to avoid adding overlapping time indices
			input_tensor = torch.zeros(1, INPUT_DIM, MAX_FRAMES)
		elif self.mode == 'agg_after':
			# handle the object aggregation downstream in the relation modules
			input_tensor = torch.zeros(self.n_timelines, INPUT_DIM, MAX_FRAMES)
		else:
			# fail here with a clear message instead of an UnboundLocalError below
			raise ValueError(f'unknown aggregation mode {self.mode}')

		shape_id_map = {}  # shape_id -> its own timeline index ('agg_after' only)
		for shape, obj_ts in ts.items():
			for shape_id, action, _start, end, probs in obj_ts:
				if self.mode == 'agg_after':
					shape_idx = shape_id_map.setdefault(shape_id, len(shape_id_map))
				else:
					# in agg_before all objects share timeline 0
					shape_idx = 0

				atomic_index = ACTION_CLASSES.index((shape, action))
				# One element-wise max against the full row merges overlaps in
				# both modes. A separate max against the [start:end] slice would
				# have that slice's length, which cannot broadcast against the row.
				row = input_tensor[shape_idx, atomic_index]
				input_tensor[shape_idx, atomic_index] = torch.max(row, torch.tensor(probs))
				if end > 300:
					# NOTE(review): hard-coded 300; possibly meant to be MAX_FRAMES
					logger.warning('movement ends at frame %s (> 300)', end)
		return input_tensor

	def __getitem__(self, idx):
		"""Return ``(input_tensor, output_tensor)`` for sample *idx*."""
		inputs, outputs = self.timelines[idx]
		input_tensor = self.format_input(inputs)

		output_tensor = torch.zeros(self.output_dim)
		output_tensor[outputs] = 1

		return input_tensor, output_tensor


