import torch
import torch.nn as nn

from collections import Counter

from temporal.models.dense import MLP

import logging

# Values below this threshold are treated as empty/zero by the masking logic in this module.
ETA = 1e-5

# NOTE(review): basicConfig configures the root logger at import time, which affects every
# importer of this module; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RelationNetwork(nn.Module):
	"""Predict pairwise temporal relations between atomic events from their time series.

	Subclasses customise the scoring by overriding ``init_model``, ``evaluate_agg`` and/or
	``evaluate_pre_agg``.
	"""

	def __init__(self, time_dim_quant, num_rela, agg_mode='agg_before', agg_type='sum', hidden_layers=None):
		"""
		Args:
			time_dim_quant: presumably the length of each quantized time series (stored only;
				the default ``init_model`` does not use it) -- TODO confirm.
			num_rela: number of relation scores predicted per (subject, object) event pair.
			agg_mode: how timelines are combined:
				agg_before: aggregate the time series before the relation prediction
				agg_after: predict relations per timeline pair (``evaluate_pre_agg``), then aggregate
				agg_other_ts: (not implemented) for the current timeline, compute the relations,
					and add in relations for all other timelines aggregated
			agg_type: reduction over the timeline axis; only 'sum' is supported.
			hidden_layers: hidden layer sizes passed to ``init_model``; None means no hidden layers.
		"""
		super().__init__()
		self.time_dim_quant = time_dim_quant
		self.num_rela = num_rela
		self.agg_mode = agg_mode
		self.agg_type = agg_type  # maybe possible to pass the torch function directly?

		# avoid a shared mutable default; init_model receives a fresh list
		hidden_layers = [] if hidden_layers is None else hidden_layers
		self.network = self.init_model(hidden_layers)

		# learnable offset added to near-zero predictions in mask_sigmoid: sigmoid maps 0 to 0.5,
		# so those entries are shifted down before the final sigmoid
		self.mask_fill = nn.Parameter(torch.tensor([-5.0]))

	def init_model(self, hidden_layers):
		"""Build the MLP mapping pairwise event features to ``num_rela`` relation scores."""
		layer_sizes = list(hidden_layers)
		# NOTE(review): the input width is fixed at 4 (two [start, end] intervals, as used by
		# RelationNetworkLogic), but evaluate_agg feeds 3x the time-series width and
		# evaluate_agg_outer_diff feeds width**2 features -- confirm the intended input.
		layer_sizes.insert(0, 4)
		# output: one score per temporal relationship
		layer_sizes.append(self.num_rela)

		return MLP(layer_sizes)

	def agg_ts(self, ts):
		"""Aggregate over the timeline axis (dim 1).

		(Batch, Timeline, Atomic Event, Timeseries) -> (Batch, Atomic Event, Timeseries)

		Raises:
			ValueError: if ``agg_type`` is not supported.
		"""
		if self.agg_type == 'sum':
			return torch.sum(ts, 1)
		raise ValueError(f'Unknown agg type {self.agg_type}')

	def evaluate_agg(self, ts):
		"""Score every ordered event pair of the timeline-aggregated series.

		ts: (Batch, Atomic Event, Timeseries) -> (Batch, Atomic Event^2 * num_rela)
		"""
		# pairwise atomic event features -> (Batch, Atomic Event^2, Timeseries * 3)
		n_events = ts.shape[1]
		subjects = torch.repeat_interleave(ts, repeats=n_events, dim=1)
		objects = ts.repeat(1, n_events, 1)
		model_input = torch.cat((subjects, objects, subjects - objects), dim=-1)

		# drop the batch, and predict relations across all events for all batches
		bs, pair_events, t_dim = model_input.shape
		model_input = model_input.reshape(bs * pair_events, t_dim)

		# applying a sigmoid to the (possibly very large) summed inputs was tried and did worse
		scores = self.network(model_input)
		# two events can have multiple occurrences before and/or after each other, so each
		# relation is scored independently with a sigmoid
		scores = torch.sigmoid(scores)

		# the argmax/Counter forces a device sync; only pay for it when INFO is enabled
		if logger.isEnabledFor(logging.INFO):
			logger.info(f'relation distribution {Counter(scores.argmax(dim=1).tolist())}')

		# flatten event relation predicates for the reasoning network downstream
		# concatenates along the rows as [(s1, r1, o1), (s1, r2, o1), (s1, r1, o2), ...]
		return scores.reshape(bs, pair_events * self.num_rela)

	def evaluate_agg_outer_diff(self, ts):
		"""Score event pairs from the outer difference of their time series.

		ts: (Batch, Atomic Event, T) -> (Batch, Atomic Event^2 * num_rela). Each pair is
		encoded as the T x T matrix ``subj[i] - obj[j]``, flattened to T**2 features.
		"""
		bs, n_events, quant_dim = ts.shape
		subjects = torch.repeat_interleave(ts, repeats=n_events, dim=1)
		objects = ts.repeat(1, n_events, 1)

		# (pairs, T, 1) - (pairs, 1, T) -> (pairs, T, T)
		subj_col = subjects.reshape(-1, quant_dim).unsqueeze(-1)
		obj_row = objects.reshape(-1, quant_dim).unsqueeze(-2)
		model_input = (subj_col - obj_row).reshape(-1, quant_dim ** 2)

		# independent per-relation scores, as in evaluate_agg
		scores = torch.sigmoid(self.network(model_input))

		# flatten as [(s1, r1, o1), (s1, r2, o1), (s1, r1, o2), ...]
		return scores.reshape(bs, n_events ** 2 * self.num_rela)

	def evaluate_pre_agg(self, ts):
		"""Score relations per timeline pair before aggregation; must be provided by subclasses."""
		raise NotImplementedError

	def agg_before(self, ts):
		"""Aggregate timelines, score event pairs, then apply ``mask_sigmoid``."""
		ts = self.agg_ts(ts)
		ts = self.evaluate_agg(ts)
		return self.mask_sigmoid(ts)

	def mask_sigmoid(self, ts):
		"""Shift entries below ETA by ``mask_fill``, then apply a sigmoid.

		Without the shift, zero entries would map to a relatively high 0.5.
		"""
		fill = (ts < ETA) * self.mask_fill
		return torch.sigmoid(ts + fill)

	def agg_after(self, ts):
		"""Score relations per timeline pair, aggregate them, then apply ``mask_sigmoid``."""
		ts = self.evaluate_pre_agg(ts)
		ts = self.agg_ts(ts)
		return self.mask_sigmoid(ts)

	def agg_other_ts(self, ts):
		"""Placeholder for the 'agg_other_ts' mode.

		Raises:
			NotImplementedError: always. The previous draft looped over timelines without
				computing anything and returned the raw input as if it were relation scores.
		"""
		raise NotImplementedError('agg_mode agg_other_ts is not implemented')

	def forward(self, ts):
		"""Dispatch to the aggregation strategy selected by ``agg_mode``.

		Raises:
			ValueError: if ``agg_mode`` is unknown.
		"""
		if self.agg_mode == 'agg_before':
			return self.agg_before(ts)
		if self.agg_mode == 'agg_after':
			return self.agg_after(ts)
		if self.agg_mode == 'agg_other_ts':
			return self.agg_other_ts(ts)
		raise ValueError(f'Unknown agg mode {self.agg_mode}')

class RelationNetworkLogic(RelationNetwork):
	"""RelationNetwork variant that reduces each event's time series to a [start, end]
	interval and scores interval pairs with TemporalLogicNetwork.
	"""

	def __init__(self, time_dim_quant, num_rela, agg_mode, agg_type):
		super().__init__(time_dim_quant, num_rela, agg_mode=agg_mode, agg_type=agg_type)

	def init_model(self, hidden_layers):
		# hidden_layers is ignored: the temporal logic network has a fixed structure.
		# NOTE(review): it always emits 3 relation columns, so num_rela presumably must be 3.
		return TemporalLogicNetwork()

	@staticmethod
	def _event_intervals(ts, pad=0.0):
		"""Reduce each row of ``ts`` (N, T) to an interval (N, 2) of [start, end].

		When the entries are time indices, the row max is the end of the durational event.
		For the start, entries below ETA are offset by the global max (plus ``pad``) so the
		row min skips them. The offset is subtracted back out where every entry of a row was
		below ETA, so e.g. an all-zero row gets start 0. A ``pad`` of ETA keeps an
		instantaneous event's start equal to its (max) value.
		"""
		peak = ts.max()
		if pad:
			peak = peak + pad
		min_mask = peak * (ts < ETA)
		event_end = ts.max(dim=1).values
		event_start = (ts + min_mask).min(dim=1).values - min_mask.min(dim=1).values
		return torch.stack((event_start, event_end), dim=1)

	@staticmethod
	def _pairwise_interval_inputs(intervals):
		"""Build (subject, object) interval features for every ordered pair of items.

		intervals: (Batch, Items, 2) -> (Batch * Items^2, 4). Rows are subject-major
		(s1 o1, s1 o2, ..., s2 o1, ...); columns are [subj start, subj end, obj start, obj end].
		"""
		bs, n_items, _ = intervals.shape
		subjects = torch.repeat_interleave(intervals, repeats=n_items, dim=1)
		objects = intervals.repeat(1, n_items, 1)

		# comparing an event against itself would always give a "during" relation: zero the
		# object interval on the diagonal pairs (flat index i * n_items + i); TemporalLogicNetwork
		# maps zero-length intervals to all-zero outputs
		self_comp_mask = torch.ones_like(objects[0])
		diagonal = torch.arange(n_items, device=objects.device) * (n_items + 1)
		self_comp_mask[diagonal, :] = 0
		objects = objects * self_comp_mask

		model_input = torch.cat((subjects, objects), dim=-1)
		# drop the batch, and predict relations across all events for all batches
		return model_input.reshape(bs * n_items * n_items, model_input.shape[-1])

	def evaluate_agg(self, ts):
		"""Score every ordered event pair of the timeline-aggregated series.

		ts: (Batch, Atomic Event, Timeseries) -> (Batch, Atomic Event^2 * num_rela)
		"""
		bs, n_events, t_dim = ts.shape
		intervals = self._event_intervals(ts.reshape(bs * n_events, t_dim))
		intervals = intervals.reshape(bs, n_events, 2)

		relations = self.network(self._pairwise_interval_inputs(intervals))

		# flatten event relation predicates for the reasoning network downstream
		# concatenates along the rows as [(s1, r1, o1), (s1, r2, o1), (s1, r1, o2), ...]
		return relations.reshape(bs, n_events * n_events * self.num_rela)

	def evaluate_pre_agg(self, ts):
		"""Run the pairwise comparison across events AND timelines.

		ts: (Batch, Timeline, Atomic Event, Timeseries)
			-> (Batch, Timeline^2, Atomic Event^2 * num_rela)
		"""
		bs, timelines, n_events, t_dim = ts.shape
		n_items = timelines * n_events

		# pad the start mask by ETA so instantaneous events keep start == end
		intervals = self._event_intervals(ts.reshape(bs * n_items, t_dim), pad=ETA)
		intervals = intervals.reshape(bs, n_items, 2)

		relations = self.network(self._pairwise_interval_inputs(intervals))

		# factor into (bs, timelines, n_events, timelines, n_events, num_rela), then group by
		# timeline pair: (bs, timelines, timelines, n_events, n_events, num_rela)
		relations = relations.reshape(bs, timelines, n_events, timelines, n_events, self.num_rela)
		relations = relations.permute(0, 1, 3, 2, 4, 5)
		return relations.reshape(bs, timelines ** 2, n_events ** 2 * self.num_rela)

class TemporalLogicNetwork(nn.Module):
	"""Soft rule-based scorer of the temporal relation between two intervals.

	Each input row is [subject start, subject end, object start, object end]; each output
	row holds (before, during, after) scores in [0, 1].
	"""

	def __init__(self):
		super().__init__()
		# learnable per-relation offset and temperature applied before the softmax
		self.shift = nn.Parameter(torch.zeros(3))
		self.scale = nn.Parameter(torch.ones(3))

	def forward(self, intervals):
		"""
		- if subject end < object start then before
		- if object end < subject start then after
		- if subject start < object end and object start < subject end then during
		- if either interval has zero (or negative) length, all outputs are 0

		Args:
			intervals: (N, >=4) tensor; columns 0-3 are
				[subject start, subject end, object start, object end].
		Returns:
			(N, 3) tensor of soft (before, during, after) scores in [0, 1].
		"""
		subject_start = intervals[:, 0]
		subject_end = intervals[:, 1]
		object_start = intervals[:, 2]
		object_end = intervals[:, 3]

		# as a design choice, each raw score is positive when its relation holds
		before = object_start - subject_end
		after = subject_start - object_end

		# "during" must satisfy both conditions; fuzzy "and" is taken as min
		# condition 1: subject interval must start before object interval ends
		cond_1 = object_end - subject_start
		# condition 2: object interval must start before the subject interval ends
		cond_2 = subject_end - object_start
		during = torch.stack((cond_1, cond_2), dim=1).min(dim=1).values

		unnorm_outputs = torch.stack((before, during, after), dim=1)
		outputs = torch.softmax((unnorm_outputs - self.shift) / self.scale, dim=1)

		# the shorter interval gates every prediction, so a missing (0, 0) subject or object
		# interval drives all outputs to 0. Clamp at 0: a reversed interval (start > end, e.g.
		# from float rounding) would otherwise make outputs negative and fail the range check.
		interval_size = torch.stack((subject_end - subject_start, object_end - object_start), dim=1).min(dim=1).values
		interval_size = interval_size.clamp(min=0)
		outputs = torch.min(outputs, interval_size.unsqueeze(1))

		assert (outputs <= 1.0).all().item() and (outputs >= 0.0).all().item(), (outputs.max(), outputs.min())

		return outputs