import torch
import torch.nn as nn

from temporal.models.neuraltlp import TemporalRelationNetwork
from temporal.inference import ACTION_CLASSES, ORDERING
from temporal.dataset import OUTPUT_DIM
from datasets.cater.enums import MAX_FRAMES

class TemporalMAP(TemporalRelationNetwork):
	"""Relation network whose forward pass always runs the temporal quantizer.

	Relies on ``temp_quant`` and ``rel_network`` being provided by the
	instance (presumably set up by ``TemporalRelationNetwork`` -- confirm
	against the base class).
	"""

	def forward(self, ts):
		"""Quantize ``ts`` and feed the result to the relation network.

		Per the original author's note, ``ts`` is a time series laid out as
		timelines x atomic events x time_dim. Quantization is applied
		unconditionally here; no ``self.quantize`` flag is consulted.

		Returns whatever ``self.rel_network`` produces for the quantized input.
		"""
		quantized = self.temp_quant(ts)
		return self.rel_network(quantized)

class TemporalLSTM(nn.Module):
	"""LSTM over atomic-event time series with a linear read-out head.

	The input is run through an ``nn.LSTM`` one time step at a time. The
	per-step LSTM outputs are then reduced to a single vector per batch item,
	either by taking the last time step (``attention_dim == 0``) or by a
	learned, attention-weighted sum over all time steps
	(``attention_dim > 0``), and that vector is mapped by a linear layer to
	``num_comp_events`` outputs.

	Args:
		atomic_events: sized collection; its length is the LSTM input size.
		relations: sized collection; only its length is stored (``num_rela``).
		num_comp_events: number of outputs of the final linear layer.
		time_dim: stored as ``self.time_dim``; not used by this class itself.
		hidden_size: LSTM hidden size (per direction).
		num_layers: number of stacked LSTM layers.
		bidirectional: whether the LSTM is bidirectional.
		attention_dim: width of the attention projection; ``0`` disables
			attention.
	"""

	def __init__(self, atomic_events=ACTION_CLASSES, relations=ORDERING, num_comp_events=OUTPUT_DIM, time_dim=MAX_FRAMES, hidden_size=128, num_layers=1, bidirectional=False, attention_dim=0):
		super().__init__()
		self.num_atomic_events = len(atomic_events)
		self.num_rela = len(relations)
		self.time_dim = time_dim
		self.num_comp_events = num_comp_events
		self.network = nn.LSTM(self.num_atomic_events, hidden_size, num_layers=num_layers, bidirectional=bidirectional)

		# nn.LSTM emits hidden_size features per direction at every time step,
		# so everything that consumes its output must use this width.
		lstm_out_size = hidden_size * 2 if bidirectional else hidden_size
		self.linear = nn.Linear(lstm_out_size, self.num_comp_events)

		self.attention_dim = attention_dim
		# save the hidden state if the current input is meaningful
		if self.attention_dim > 0:
			# The attention input is [LSTM output, raw input] concatenated on
			# the feature axis. Sizing this with hidden_size alone broke the
			# matmul in forward() whenever bidirectional=True.
			self.proj = nn.Parameter(torch.rand(self.num_atomic_events + lstm_out_size, self.attention_dim, requires_grad=True))
			self.dot = nn.Parameter(torch.rand(self.attention_dim, 1, requires_grad=True))

	def forward(self, ts):
		"""Map a batch of time series to ``num_comp_events`` outputs.

		Args:
			ts: 3-D tensor of shape (batch_size, num_events, ts_len), where
				num_events must equal ``self.num_atomic_events``.

		Returns:
			Tensor of shape (batch_size, num_comp_events).
		"""
		batch_size, _, ts_len = ts.shape
		# nn.LSTM (batch_first=False) expects (ts_len, batch_size, num_events)
		ts_reformat = ts.permute(2, 0, 1)
		self.network.flatten_parameters()
		output, _ = self.network(ts_reformat)
		if self.attention_dim > 0:
			# concatenate the LSTM outputs and model inputs, one row per (step, item)
			cat = torch.cat((output, ts_reformat), dim=2).reshape(ts_len * batch_size, -1)

			# learnable projection + nonlinear
			proj = torch.tanh(torch.mm(cat, self.proj))

			# one scalar score per (step, item); scores are used as-is
			# NOTE(review): no softmax/normalization over time -- confirm intended
			attn = torch.mm(proj, self.dot).reshape(ts_len, batch_size, 1)

			# weight the LSTM outputs (broadcast over the feature axis) and
			# then marginalize over time
			linear_inputs = (attn * output).sum(dim=0)
		else:
			# NOTE(review): with bidirectional=True the reverse-direction half
			# of output[-1] has only seen the final time step -- confirm intended
			linear_inputs = output[-1]
		return self.linear(linear_inputs)