import torch.nn as nn

from temporal.models.quantize import TemporalQuantization
from temporal.models.relation import RelationNetworkLogic
from temporal.models.reasoning import ReasoningNetwork, ReasoningNetworkVariable

from temporal.inference import ACTION_CLASSES, ORDERING
from temporal.dataset import OUTPUT_DIM
from datasets.cater.enums import MAX_FRAMES

import logging

# NOTE(review): basicConfig configures the *root* logger as a side effect of
# importing this module; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

class TemporalRelationNetwork(nn.Module):
	"""Event-timeline model built from three stacked stages.

	1. ``temp_quant`` (TemporalQuantization) compresses the temporal dimension.
	2. ``rel_network`` (RelationNetworkLogic) predicts the pairwise relations
	   between events across timelines.
	3. ``reasoning_network`` (ReasoningNetwork) maps those logical relations
	   to the provided labels.

	Each stage is created by an ``init_*`` hook that subclasses may override.
	The constructor calls the hooks in the order listed above.

	Args:
		atomic_events: atomic event classes; ``len(atomic_events)`` is passed
			to ``TemporalQuantization``.
		relations: relation types; ``len(relations)`` is passed to
			``RelationNetworkLogic``.
		num_comp_events: passed to ``ReasoningNetwork`` (presumably the
			number of output labels; confirm).
		time_dim, time_dim_quant, time_mult, conv_fill: forwarded to
			``TemporalQuantization`` (``conv_fill`` as ``fill``).
			``time_dim_quant`` is also given to ``RelationNetworkLogic``.
		agg_mode, agg_type: forwarded to ``RelationNetworkLogic``.
	"""

	def __init__(self, atomic_events=ACTION_CLASSES, relations=ORDERING, num_comp_events=OUTPUT_DIM, time_dim=MAX_FRAMES, time_dim_quant=4, time_mult=True,
		agg_mode='agg_before', agg_type='sum', conv_fill=None):
		super().__init__()

		# Vocabularies and the sizes derived from them.
		self.atomic_events, self.relations = atomic_events, relations
		self.num_atomic_events, self.num_rela = len(atomic_events), len(relations)
		self.num_comp_events = num_comp_events

		# Temporal-quantization settings.
		self.time_dim, self.time_dim_quant, self.time_mult = time_dim, time_dim_quant, time_mult
		self.conv_fill = conv_fill

		# Relation-aggregation settings.
		self.agg_mode, self.agg_type = agg_mode, agg_type

		# Build the three stages, in pipeline order.
		self.init_temp_quant()
		self.init_rel_network()
		self.init_reasoning_network()

	def init_temp_quant(self):
		"""Build ``self.temp_quant``, which compresses the temporal dimension."""
		self.temp_quant = TemporalQuantization(
			self.num_atomic_events,
			time_dim=self.time_dim,
			time_dim_quant=self.time_dim_quant,
			time_mult=self.time_mult,
			fill=self.conv_fill,
		)

	def init_rel_network(self):
		"""Build ``self.rel_network``, which predicts pairwise event relations across timelines."""
		self.rel_network = RelationNetworkLogic(
			self.time_dim_quant,
			self.num_rela,
			agg_mode=self.agg_mode,
			agg_type=self.agg_type,
		)

	def init_reasoning_network(self):
		"""Build ``self.reasoning_network``, which maps the logical relations to the labels."""
		self.reasoning_network = ReasoningNetwork(
			self.atomic_events,
			self.relations,
			self.num_comp_events,
		)

	def forward(self, ts):
		"""Run quantization -> relation network -> reasoning network.

		``ts`` is documented by the original author as
		timelines x atomic events x time_dim (TODO confirm whether a batch
		axis is also present).

		Returns:
			``(labels, rel_out)`` in training mode, where ``rel_out`` is the
			relation-network output; only ``labels`` in eval mode.
		"""
		quantized = self.temp_quant(ts)
		rel_out = self.rel_network(quantized)
		labels = self.reasoning_network(rel_out)
		return (labels, rel_out) if self.training else labels

class TemporalRelationNetworkVariable(TemporalRelationNetwork):
	"""``TemporalRelationNetwork`` plus a ``ReasoningNetworkVariable`` head.

	Builds ``reasoning_network_var`` in addition to the base-class stages.
	Two public attributes act as mode switches for ``forward``, and both
	start as False:

	* ``relations``: if truthy, ``forward`` returns the relation-network
	  output directly and skips the reasoning step.
	* ``rule_search``: if truthy, labels come from ``reasoning_network_var``
	  rather than ``reasoning_network``.

	``relations`` is reused as a flag here. The value the base class stored
	under that name (the ``relations`` constructor argument) is therefore
	preserved as ``relation_types`` before it is overwritten.

	Extra args, all forwarded to ``ReasoningNetworkVariable``:
		max_rule_len, attn_cand_beam, variable_len, max_rules_beam.
	"""

	def __init__(self, atomic_events=ACTION_CLASSES, relations=ORDERING, num_comp_events=OUTPUT_DIM, time_dim=MAX_FRAMES, time_dim_quant=4, time_mult=True,
		agg_mode='agg_before', agg_type='sum', conv_fill=None, max_rule_len=1, attn_cand_beam=10, variable_len=True, max_rules_beam=100):
		# These are plain (non-Module, non-Parameter) values, which nn.Module
		# allows to be assigned before its own __init__ has run. Setting them
		# first keeps them available to any init_* hook that the base
		# constructor might call in an overriding subclass.
		self.max_rule_len = max_rule_len
		self.attn_cand_beam = attn_cand_beam
		self.variable_len = variable_len
		self.max_rules_beam = max_rules_beam

		super().__init__(atomic_events=atomic_events, relations=relations, num_comp_events=num_comp_events, time_dim=time_dim, time_dim_quant=time_dim_quant, time_mult=time_mult,
			agg_mode=agg_mode, agg_type=agg_type, conv_fill=conv_fill)

		# Keep the relation set before ``self.relations`` is repurposed as a
		# boolean mode flag below. Without this it would be lost, and any
		# later rebuild of the variable head would receive ``False``.
		self.relation_types = self.relations

		self.init_variable_reasoning()
		self.rule_search = False
		self.relations = False

	def init_variable_reasoning(self):
		"""Build ``self.reasoning_network_var``, which maps the logical relations to the labels.

		Reads the relation set from ``relation_types`` because, after
		construction, ``relations`` holds a boolean mode flag. Falls back to
		``relations`` if invoked before ``relation_types`` has been set.
		"""
		relation_types = getattr(self, 'relation_types', self.relations)
		self.reasoning_network_var = ReasoningNetworkVariable(
			self.atomic_events, relation_types, self.num_comp_events, self.max_rule_len,
			attn_cand_beam=self.attn_cand_beam, variable_len=self.variable_len, max_rules_beam=self.max_rules_beam)

	def forward(self, ts):
		"""Quantize, compute relations, then optionally reason over them.

		Returns:
			The relation-network output alone when ``self.relations`` is
			truthy. Otherwise ``(labels, rel_out)`` in training mode and
			``labels`` in eval mode. Labels come from
			``reasoning_network_var`` if ``self.rule_search`` is truthy, and
			from ``reasoning_network`` otherwise.
		"""
		ts = self.temp_quant(ts)
		ts = self.rel_network(ts)
		if self.relations:
			# Relation-only mode: skip both reasoning heads.
			return ts

		head = self.reasoning_network_var if self.rule_search else self.reasoning_network
		labels = head(ts)

		if self.training:
			return labels, ts
		return labels