import torch
import torch.nn as nn

from collections import defaultdict
from itertools import combinations

from datasets.cater.enums import actions_order_dataset, reverse
from datasets.cater.n_predicate_labels import build_two_pred
from datasets.cater.generate import generate_rules

import logging

# NOTE(review): calling basicConfig at import time configures the root logger for every
# program that imports this module; consider moving it to the training/eval entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Round-off tolerance for the [0, 1] range check on rule evaluations (ReasoningNetworkVariable.forward).
ETA = 1.0e-5

class ReasoningNetworkVariable(nn.Module):
	"""Learn, per composite-event label, a soft choice among candidate conjunctive rules.

	Candidate rules are all combinations (length 1..max_rule_len, or exactly max_rule_len when
	``variable_len`` is False) of the ``attn_cand_beam`` highest-attention ground predicates of
	each label (see ``gen_combs``). The learnable ``select`` vector is softmaxed per label in
	``forward`` and weights those candidates; ``forward`` then scores a batch of predicate
	activation vectors against the weighted candidates.

	``gen_combs`` must be called before ``forward`` so that ``label_comb_matrix`` exists.
	"""
	def __init__(self, atomic_events, relations, num_comp_events, max_rule_len, attn_cand_beam=10, variable_len=True,
		max_rules_beam=100, sparse_combs=True):
		super().__init__()
		self.atomic_events = atomic_events
		self.relations = relations
		self.num_labels = num_comp_events
		self.max_rule_len = max_rule_len
		self.attn_cand_beam = attn_cand_beam
		self.variable_len = variable_len
		self.max_rules_beam = max_rules_beam
		self.sparse_combs = sparse_combs  # use sparse tensors to represent the combination matrices

		self.unique_predicates = len(self.atomic_events)**2 * len(self.relations)
		self.combs_per_label = None  # compute later
		self.batch_size = None

		# pre-determine the parameter size: every label gets the same number of candidate
		# combinations, which only depends on the beam size and the rule lengths
		self.combs_per_label = self.generate_comb_matrix(search_indices=range(self.attn_cand_beam)).size(1)

		# select which combinatorial outputs to use for each label (label-major layout)
		select = torch.ones(self.num_labels * self.combs_per_label, dtype=torch.float) * 1/self.max_rule_len
		self.select = nn.Parameter(select, requires_grad=True)

		# ground-truth rules (rule -> label index), used only for the recall / MRR diagnostics
		self.rules = generate_rules(n_predicates=self.max_rule_len, variable_len=self.variable_len, max_rules_beam=self.max_rules_beam)

		# index -> ground predicate mapping used for rule construction
		# NOTE(review): len(atomic_events) above means atomic_events=None never reaches the first branch
		if atomic_events is None:
			self.grnd_pred_mapping = actions_order_dataset(n=2, unique=False)
		else:
			self.grnd_pred_mapping = actions_order_dataset(atomic_events=atomic_events, n=2, unique=False)

	def gen_combs(self, attention_weights=None, compute_stats=True):
		"""Build and cache the candidate-rule matrix from trained attention weights.

		Args:
			attention_weights: tensor whose row i scores every ground predicate for label i.
				Required despite the ``None`` default.
			compute_stats: if True, log how well the beam recovers the ground-truth rules.
		"""
		# rank predicates per label by absolute attention, highest first
		sorted_weights = torch.flip(torch.argsort(torch.abs(attention_weights), dim=1), dims=[1])
		fact_indices = self.get_fact_indices(sorted_weights)

		if compute_stats:
			# check whether training + beam captured the ground-truth predicates
			self.compute_stats(fact_indices, sorted_weights)

		self.get_comb_matrix(fact_indices)

	def compute_stats(self, fact_indices, sorted_weights):
		"""Log beam recall and predicate MRR against ``self.rules``."""
		mean_recall = self.check_comb_recall(fact_indices)
		mean_recip_rank = self.check_comb_mrr(sorted_weights)
		logger.info('Rule combination recall: %s', mean_recall)
		logger.info('Rule predicate MRR: %s', mean_recip_rank)

	def _store_label_comb_matrix(self, fact_indices, build):
		"""Concatenate one candidate matrix per label (label-major columns) and cache it."""
		label_comb_matrices = [build(search_indices=indices) for indices in fact_indices]
		if self.combs_per_label is None and label_comb_matrices:
			self.combs_per_label = label_comb_matrices[0].size(1)
		self.label_comb_matrix = torch.cat(label_comb_matrices, dim=1)  # (unique_preds, combs_per_label * num_labels)
		self.batch_comb_matrix = None  # cache batched version for later, @setter

	def get_comb_matrix(self, fact_indices):
		"""Cache the candidate matrix (sparse or dense according to ``sparse_combs``)."""
		self._store_label_comb_matrix(fact_indices, self.generate_comb_matrix)

	def get_comb_matrix_sparse(self, fact_indices):
		"""Cache the candidate matrix as a sparse COO tensor regardless of ``sparse_combs``."""
		self._store_label_comb_matrix(fact_indices, self.generate_comb_matrix_sparse)

	def get_fact_indices(self, sorted_weights):
		"""Take up to ``attn_cand_beam`` predicate indices per label, skipping reverses of kept ones.

		Args:
			sorted_weights: per-label predicate indices ordered from most to least attended.

		Returns:
			list (one per label) of int predicate indices.
		"""
		fact_indices = []
		for weight_indices in sorted_weights:
			label_indices = []
			for idx in weight_indices.tolist():
				if len(label_indices) == self.attn_cand_beam:
					break
				# a predicate and its reverse encode the same relation; keep only one of them
				rev_idx = self.grnd_pred_mapping.index(reverse(self.grnd_pred_mapping[idx]))
				if rev_idx not in label_indices:
					label_indices.append(idx)
			fact_indices.append(label_indices)
		return fact_indices

	def check_comb_recall(self, fact_indices):
		"""Fraction of ground-truth rules, grouped by length, whose predicates all made the beam.

		A rule predicate counts as found when it or its ``reverse`` is in the rule label's beam.
		Indices are resolved with ``self.grnd_pred_mapping`` so they agree with ``get_fact_indices``.
		"""
		mapping = self.grnd_pred_mapping
		recall = defaultdict(list)
		for rule, label_idx in self.rules.items():
			beam = fact_indices[label_idx]
			found = all(
				mapping.index(rule_predicate) in beam or mapping.index(reverse(rule_predicate)) in beam
				for rule_predicate in rule
			)
			recall[len(rule)].append(found)
		return {f'rule len {rule_len}': sum(hits)/len(hits) for rule_len, hits in recall.items()}

	def check_comb_mrr(self, sorted_weights):
		"""Mean reciprocal rank of each ground-truth rule predicate in its label's attention ranking."""
		recip_ranks = defaultdict(list)
		for rule, label_idx in self.rules.items():
			ranking = sorted_weights[label_idx]
			for rule_predicate in rule:
				rule_predicate_idx = self.grnd_pred_mapping.index(rule_predicate)
				rank = torch.where(ranking == rule_predicate_idx)[0].item() + 1
				recip_ranks[len(rule)].append(1/rank)
		return {f'rule len {rule_len}': sum(ranks)/len(ranks) for rule_len, ranks in recip_ranks.items()}

	def generate_comb_matrix(self, search_indices=None, len_shift=0.2):
		"""Return the (unique_predicates, n_combinations) candidate mask for one label.

		Sparse COO when ``sparse_combs`` is True, dense otherwise. ``search_indices`` defaults to
		every predicate.
		"""
		return self._build_comb_matrix(search_indices, len_shift, sparse=self.sparse_combs)

	def generate_comb_matrix_sparse(self, search_indices=None, len_shift=0.2):
		"""Same as ``generate_comb_matrix`` but always returns a sparse COO tensor."""
		return self._build_comb_matrix(search_indices, len_shift, sparse=True)

	def _build_comb_matrix(self, search_indices, len_shift, sparse):
		"""Concatenate the per-rule-length combination masks column-wise."""
		predicate_indices = range(self.unique_predicates) if search_indices is None else search_indices
		build = self._generate_comb_matrix_sparse if sparse else self._generate_comb_matrix
		start_len = 1 if self.variable_len else self.max_rule_len
		sub_matrices = [build(predicate_indices, rule_len, len_shift) for rule_len in range(start_len, self.max_rule_len + 1)]
		return torch.cat(sub_matrices, dim=1)

	def _generate_comb_matrix(self, predicate_indices, rule_len, len_shift):
		"""Dense mask with one column per ``rule_len``-combination of ``predicate_indices``.

		Each predicate in a combination gets weight 1/(rule_len + len_shift). For 0/1 predicate
		activations with len_shift=0.2: a satisfied 1-predicate rule scores 1/1.2 ~ 0.83, a fully
		satisfied 2-predicate rule scores 2/2.2 ~ 0.91 (so the longer rule is preferred when all of
		its predicates hold), and a half-satisfied 2-predicate rule scores 1/2.2 ~ 0.45.
		"""
		len_combs = list(combinations(predicate_indices, rule_len))
		combinatorial_sub_matrix = torch.zeros(self.unique_predicates, len(len_combs))
		weight = 1/(rule_len + len_shift)
		for comb_idx, comb in enumerate(len_combs):
			combinatorial_sub_matrix[list(comb), comb_idx] = weight
		return combinatorial_sub_matrix

	def _generate_comb_matrix_sparse(self, predicate_indices, rule_len, len_shift):
		"""Sparse COO equivalent of ``_generate_comb_matrix`` (same weights and column order)."""
		len_combs = list(combinations(predicate_indices, rule_len))
		rows = [pred_idx for comb in len_combs for pred_idx in comb]
		cols = [comb_idx for comb_idx, comb in enumerate(len_combs) for _ in comb]
		values = [1/(rule_len + len_shift)] * len(rows)
		# explicit long dtype keeps the (2, 0) index tensor valid when there are no combinations
		indices = torch.tensor([rows, cols], dtype=torch.long)
		return torch.sparse_coo_tensor(indices, values, size=(self.unique_predicates, len(len_combs)))

	def _get_comb_matrix(self, device):
		"""Return the cached candidate matrix, moving it to ``device`` if needed."""
		if self.label_comb_matrix.device != device:
			self.label_comb_matrix = self.label_comb_matrix.to(device)
		return self.label_comb_matrix

	def _sparse_hadamard_reshape(self, batch_sparse, dense):
		"""Scale each column of a sparse (unique_preds, num_labels * combs_per_label) matrix.

		Multiplies column j by ``dense[j]`` and reshapes the result to
		(unique_preds, num_labels, combs_per_label). Label and combination indices are derived
		from the column index (columns are label-major), so this works for coalesced and
		uncoalesced inputs alike.
		"""
		pred_idx, all_label_idx = batch_sparse._indices()
		label_idx = torch.div(all_label_idx, self.combs_per_label, rounding_mode='floor')
		comb_idx = all_label_idx % self.combs_per_label
		new_indices = torch.stack([pred_idx, label_idx, comb_idx])

		new_values = batch_sparse._values() * dense[all_label_idx]

		shape = (self.unique_predicates, self.num_labels, self.combs_per_label)
		return torch.sparse_coo_tensor(new_indices, new_values, size=shape, device=torch.device(dense.device), requires_grad=True)

	def forward(self, x):
		"""Score predicate activations against each label's softly-selected rule.

		Args:
			x: (batch_size, unique_predicates) predicate activations.

		Returns:
			(batch_size, num_labels) rule evaluations, asserted to lie in [0, 1] up to ``ETA``.
		"""
		_, num_predicates = x.shape
		assert num_predicates == self.unique_predicates

		select = self.select.reshape(self.num_labels, self.combs_per_label)
		select = select.softmax(dim=1)  # softmax across the combination weights, to choose 1
		select = select.reshape(self.num_labels * self.combs_per_label)

		comb_matrix = self._get_comb_matrix(x.device)
		if comb_matrix.is_sparse:
			weighted_combs = self._sparse_hadamard_reshape(comb_matrix, select)
			# the large (combinations) dimension is summed away here, so go back to dense
			weighted_sum_combs = torch.sparse.sum(weighted_combs, dim=-1).to_dense()
		else:
			weighted_combs = comb_matrix * select
			weighted_combs = weighted_combs.reshape(num_predicates, self.num_labels, self.combs_per_label)
			weighted_sum_combs = weighted_combs.sum(dim=-1)
		# weighted_sum_combs: (unique_preds, num_labels), shared by every batch element, so a plain
		# matmul replaces materialising a per-batch copy for bmm
		comb_eval = x @ weighted_sum_combs  # (batch_size, num_labels)

		# combination values should add up to [0, 1]
		test_min, test_max = comb_eval.min(), comb_eval.max()
		assert test_min + ETA >= 0 and test_max - ETA <= 1, (test_min, test_max)
		return comb_eval

	def build_var_len_rules(self, k=5, include_weights=False):
		"""Return, for every label, its ``k`` highest-weighted candidate rules.

		Each rule is a tuple of ground predicates ordered by predicate index. With
		``include_weights`` each entry is ``(softmax_weight, rule)`` instead.
		"""
		combination_weights = self.select.reshape(self.num_labels, self.combs_per_label)
		top_comb_indices = torch.flip(torch.argsort(combination_weights, dim=1), dims=[1])[:, :k]
		selected_weights = torch.softmax(combination_weights, dim=1) if include_weights else None

		label_combs = self.label_comb_matrix
		if label_combs.is_sparse:
			label_combs = label_combs.to_dense()
		per_label_combs = label_combs.reshape(self.unique_predicates, self.num_labels, self.combs_per_label)
		per_label_combs = per_label_combs.permute(1, 2, 0)  # (num_labels, combs_per_label, unique_predicates)

		rules = []
		for label_idx, (comb_indices, label_masks) in enumerate(zip(top_comb_indices, per_label_combs)):
			top_rules = []
			for comb_idx in comb_indices.tolist():
				pred_indices = torch.where(label_masks[comb_idx] > 0)[0].tolist()
				# dict.fromkeys de-duplicates while keeping a deterministic (index) order
				rule = tuple(dict.fromkeys(self.grnd_pred_mapping[pred_idx] for pred_idx in pred_indices))
				if include_weights:
					top_rules.append((selected_weights[label_idx][comb_idx].item(), rule))
				else:
					top_rules.append(rule)
			rules.append(top_rules)
		return rules

class ReasoningNetwork(nn.Module):
	"""Bias-free linear map from ground-predicate activations to composite-event scores.

	The projection weight (num_comp_events, unique_predicates) doubles as a rule-weight matrix
	from which the ``build_top_k_rules*`` helpers read off the strongest predicates per event.
	"""
	def __init__(self, atomic_events, relations, num_comp_events):
		super().__init__()
		self.atomic_events = atomic_events
		self.relations = relations
		self.num_comp_events = num_comp_events
		self.unique_predicates = len(self.atomic_events)**2 * len(self.relations)
		self.projection = nn.Linear(self.unique_predicates, self.num_comp_events, bias=False)
		self.rule_weights = self.projection.weight  # note weights stored as (output dim x input dim)

	def forward(self, x):
		"""Project predicate activations to per-event scores.

		In training mode each input entry is kept with probability 0.5.
		NOTE(review): unlike nn.Dropout the kept entries are not rescaled by 1/(1-p), so train and
		eval activations differ in scale; confirm this is intended.
		"""
		if self.training:
			dropout_mask = (torch.rand_like(x) < .5)
			x = x * dropout_mask
		return self.projection(x)

	def build_top_k_rules(self, k=1, add_cons_facts=True, n=2):
		"""
		for each class index, return a conjunction of temporal predicates

		Args:
			k: number of highest-weighted predicates to take per class (also caps the number of
				two-predicate rules returned when ``n > 2``).
			add_cons_facts: also add ``reverse(fact)`` for every selected fact; a reverse fact
				inherits the weight of the fact it was derived from.
			n: when greater than 2, pair facts with ``build_two_pred`` and return the k pairs with
				the largest weight product.
		"""
		fact_mapping = actions_order_dataset(n=2, unique=False)

		# ascending sort: the last k entries are the k largest weights
		sorted_weights, sorted_indices = torch.sort(self.rule_weights, dim=1)
		fact_indices = sorted_indices[:, -k:]
		fact_weights = sorted_weights[:, -k:]

		rules = []
		for output_rules, weights in zip(fact_indices, fact_weights):
			# facts and fact_ws stay aligned so each fact is paired with its own weight below
			facts, fact_ws = [], []
			for fact_idx, weight in zip(output_rules.tolist(), weights):
				fact = fact_mapping[fact_idx]
				if fact not in facts:
					facts.append(fact)
					fact_ws.append(weight)

				if add_cons_facts:
					cons_fact = reverse(fact)
					if cons_fact not in facts:
						facts.append(cons_fact)
						fact_ws.append(weight)
			if n > 2:
				two_pred_rules = []
				for obj_fact, obj_w in zip(facts, fact_ws):
					for subj_fact, subj_w in zip(facts, fact_ws):
						two_pred_rule = build_two_pred(obj_fact, subj_fact)
						if two_pred_rule is not None:
							two_pred_rules.append((obj_w * subj_w, two_pred_rule))
				if two_pred_rules:
					# sort on the weight only; ties keep insertion order instead of comparing rules
					two_pred_rules = sorted(two_pred_rules, key=lambda scored: float(scored[0]), reverse=True)
					_, two_pred_rules = zip(*two_pred_rules)
				rules.append(two_pred_rules[:k])
			else:
				rules.append(facts)
		return rules

	def build_top_k_rules_dynamic_(self, stop_thresh=0.1):
		"""Per class, grow a rule from the top predicate while the next weight stays close.

		Starting from the highest-weighted predicate, predicate i+1 is appended while
		(w_i - w_{i+1}) / w_{i+1} < ``stop_thresh``; at most 10 extra predicates are considered.
		NOTE(review): the relative change divides by w_{i+1}, which misbehaves for zero or
		negative weights.
		"""
		# ie the flattened M_R matrix, r_i(e_1, e_2)
		fact_mapping = actions_order_dataset(n=2, unique=False)

		# sort then move to descending order
		sorted_weights, sorted_indices = torch.sort(self.rule_weights, dim=1)
		fact_indices = torch.flip(sorted_indices, [1])
		fact_weights = torch.flip(sorted_weights, [1])

		rules = []
		MAX_RULE_LEN = 10
		for output_rules, weights in zip(fact_indices, fact_weights):
			# rules will contain at least one predicate
			facts = [fact_mapping[int(output_rules[0])]]
			# clamp so head/tail have equal length even with few predicates
			limit = min(MAX_RULE_LEN, weights.numel() - 1)
			head, tail = weights[:limit], weights[1:limit + 1]
			norm_weight_delta = (head - tail) / tail
			# delta i compares predicate i with predicate i+1, so it decides on predicate i+1
			for pred_idx, change in zip(output_rules[1:limit + 1].tolist(), norm_weight_delta):
				if change < stop_thresh:
					facts.append(fact_mapping[pred_idx])
				else:
					break
			rules.append(facts)
		return rules

	def build_top_k_rules_dynamic(self, k=10, append_thresh=0.01, rule_len=3):
		"""Per class, enumerate candidate rules from the k strongest (reverse-deduplicated) predicates.

		Args:
			k: number of top predicates considered, and maximum number of rules returned per class.
			append_thresh: a combination is extended with its next predicate only while the
				normalised score drops by less than this amount from the previous predicate.
			rule_len: maximum predicates per rule (previously hard-coded to 3); 1 returns one
				single-predicate rule per unique predicate.
		"""
		# ie the flattened M_R matrix, r_i(e_1, e_2)
		fact_mapping = actions_order_dataset(n=2, unique=False)

		# sort then move to descending order
		sorted_weights, sorted_indices = torch.sort(self.rule_weights, dim=1)
		fact_indices = torch.flip(sorted_indices, [1])
		fact_weights = torch.flip(sorted_weights, [1])

		pred_rules = []
		for relation_idx, weight in zip(fact_indices, fact_weights):
			relations = [fact_mapping[i] for i in relation_idx[:k].tolist()]
			scores = weight[:k]
			scores = scores / scores.sum()

			# keep the first of each (relation, reverse(relation)) pair
			unique_rel = []
			unique_scores = []
			for relation, score in zip(relations, scores):
				if relation not in unique_rel and reverse(relation) not in unique_rel:
					unique_rel.append(relation)
					unique_scores.append(score)

			if rule_len > 1:
				# combinatorial search over the k relations
				scored_relations = list(zip(unique_rel, unique_scores))
				num_combs = min(len(scored_relations), rule_len)
				rules = []
				for comb in combinations(scored_relations, num_combs):
					rule = []
					prev_score = comb[0][1]
					for predicate, score in comb:
						if prev_score - score < append_thresh:
							rule.append(predicate)
							prev_score = score
						else:
							break
					rules.append(tuple(rule))
			else:
				rules = [(relation,) for relation in unique_rel]

			pred_rules.append(rules[:k])
		return pred_rules

	def build_rule_tree(self, max_purity=0.8):
		raise NotImplementedError