from typing import DefaultDict
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

import os
from pathlib import Path

import json
import numpy as np
from math import floor
from tqdm import tqdm
from copy import deepcopy
from collections import defaultdict, deque, Counter
from itertools import combinations

from datasets.cater.enums import ACTION_CLASSES, ORDERING, actions_order_dataset, actions_order_mapping, reverse, check_rule_match
from datasets.cater.utils import rule_examples
from datasets.cater.generate import folder_format, get_path, generate_rules

from vision.visionutils import load_labels, MAX_FRAMES
from temporal.state_space import VisionLoader, VideoStateSpace, GroundTruthMovementLoader, VideoStateSpaceGT
# from temporal.model import *

from temporal.models.neuraltlp import TemporalRelationNetworkVariable, TemporalRelationNetwork
from temporal.models.baselines import TemporalMAP, TemporalLSTM
from temporal.models.dense import TemporalModel

from temporal.dataset import *
from temporal.map_voc import compute_multiple_aps, compute_mAP

import logging

# Configure root logging at import time (INFO level), then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Small epsilon used as a clamping margin for probabilities / weights kept inside [0, 1].
ETA = 0.00001

class TemporalTraining:
	"""Base trainer: reads hyper-parameters from a config object, runs the train / validate
	loop and saves / loads checkpoints.

	Subclasses customise behaviour by overriding get_model, compute_loss, post_step and
	evaluate_model_voc.
	"""

	def __init__(self, config):
		"""Copy paths and hyper-parameters from ``config`` and create ``config.model_path``.

		Negative ``config.samples`` / ``config.train_samples`` mean "use every sample" and are
		stored as None.
		"""
		self.exp_name = config.exp_name
		self.data_dir = config.cater_path
		self.model_path = config.model_path
		self.save_path = os.path.join(self.model_path, f'{self.exp_name}.pth')
		Path(self.model_path).mkdir(parents=True, exist_ok=True)

		self.epochs = config.epochs
		self.val_after_n_epochs = config.val_interval
		self.proj_interval = config.proj_interval
		self.batch_size = config.batch_size
		self.lr = config.lr
		self.samples = None if config.samples < 0 else config.samples
		self.train_samples = None if config.train_samples < 0 else config.train_samples
		self.num_workers = config.loader_workers
		self.device = config.device
		self.tqdm = config.tqdm

		# when the reasoning mapping is projected onto [0, 1] the predictions are already
		# probabilities, so plain BCE is used; otherwise the model emits logits
		if self.proj_interval > 0:
			self.loss_func = nn.BCELoss()
		else:
			self.loss_func = nn.BCEWithLogitsLoss()
		self.gt_atomic = config.gt

		self.config = config

	def compute_loss(self, pred, label, model):
		"""Return the configured loss between ``pred`` and ``label`` (``model`` is unused here)."""
		return self.loss_func(pred, label)

	def post_step(self, model, batch_num):
		"""Hook run after every optimizer step; a no-op in the base class."""
		pass

	def get_optimizer(self, model):
		"""Return an Adam optimizer over the parameters of ``model`` that still require grad."""
		# if we have frozen some layers, don't compute those grads
		params = filter(lambda p: p.requires_grad, model.parameters())
		optimizer = torch.optim.Adam(params, lr=self.lr)
		return optimizer

	def train_model(self, model, train_loader, val_loader, cv_loader=None, save=True):
		"""Train ``model`` for ``self.epochs`` epochs, then evaluate it on ``val_loader``.

		Every ``self.val_after_n_epochs`` epochs (0 disables this) the model is evaluated:
		on ``cv_loader`` when given, keeping the best-mAP snapshot (also written to disk when
		``save`` is true), otherwise on ``val_loader`` for logging only. The final evaluation
		uses the best cross-validated snapshot if one was taken, else the final weights.

		Returns:
			(model, optimizer): the model as left by the last optimizer step, and its optimizer.
		"""
		model = model.to(self.device)

		optimizer = self.get_optimizer(model)

		logger.info(f'Training on device: {self.device}')
		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		# best snapshot on cv_loader; stays None until a cross-validation round runs so the final
		# evaluation never falls back to the untrained initial weights
		save_model = None
		best_mAP = 0.0
		batch_num = 1

		for epoch in range(1, self.epochs + 1):
			model.train()
			losses = []
			for sample, label in iter_wrapper(train_loader):
				sample, label = sample.to(self.device), label.to(self.device)
				optimizer.zero_grad()
				pred = model(sample)
				loss = self.compute_loss(pred, label, model)
				losses.append(loss.item())
				loss.backward()
				optimizer.step()
				# batch_num counts optimizer steps across all epochs
				self.post_step(model, batch_num)
				batch_num += 1
			logger.info(f'Epoch {epoch} mean batch loss = {np.mean(losses)}')
			# an interval of 0 disables periodic validation instead of raising ZeroDivisionError
			if self.val_after_n_epochs and epoch % self.val_after_n_epochs == 0:
				# use a cross validation set for early stopping
				if cv_loader is not None:
					mAP = self.evaluate_model_voc(model, cv_loader)
					if mAP >= best_mAP:
						# TODO: keep the snapshot on the CPU instead of taking up GPU memory (or only write to disk)
						logger.info('copying best model')
						save_model = deepcopy(model)
						best_mAP = mAP
						if save:
							self.save_model(model)
							logger.info('saving model to disk')
				else:
					self.evaluate_model_voc(model, val_loader)
		logger.info('evaluating test data')
		self.evaluate_model_voc(save_model if save_model is not None else model, val_loader)
		return model, optimizer

	def save_model(self, model):
		"""Write ``model.state_dict()`` to ``self.save_path``."""
		torch.save(model.state_dict(), self.save_path)

	def evaluate_model_voc(self, model, val_loader, print_score=True):
		"""Compute the VOC-style mAP of ``model`` over every batch of ``val_loader``.

		Predictions are clamped to [ETA, 1] when projecting (the model already outputs
		probabilities), otherwise passed through a sigmoid. The mAP is logged when
		``print_score`` is true and always returned.
		"""
		model.eval()
		all_predictions = []
		all_labels = []
		with torch.no_grad():
			for sample, label in val_loader:
				sample, label = sample.to(self.device), label.to(self.device)
				if self.proj_interval > 0:
					# using BCELoss, so make sure values are [0, 1]
					pred = model(sample).clamp(min=ETA, max=1.0)
				else:
					pred = torch.sigmoid(model(sample))
				all_predictions.append(pred.cpu().numpy())
				all_labels.append(label.cpu().numpy())
		all_predictions = np.vstack(all_predictions)
		all_labels = np.vstack(all_labels)
		ap = compute_multiple_aps(all_labels, all_predictions)
		mAP = compute_mAP(ap)
		if print_score:
			logger.info(f'validation mAP {mAP}')
		return mAP

	def get_model(self):
		"""Build the default dense TemporalModel (INPUT_DIM / OUTPUT_DIM presumably come from temporal.dataset)."""
		return TemporalModel(INPUT_DIM, out_dim=OUTPUT_DIM)

	def load_model(self):
		"""Build a fresh model and load the weights stored at ``self.save_path`` into it."""
		model = self.get_model()
		model_dict = model.state_dict()

		logger.info(f'Loading model from: {self.save_path}')
		# map onto the configured device so GPU checkpoints can be loaded on CPU-only hosts
		checkpoint = torch.load(self.save_path, map_location=torch.device(self.device))
		# the select parameter's shape is tied to the beam width; when testing another width keep
		# the freshly initialised tensor. Models without this parameter are loaded as-is.
		select_key = 'reasoning_network_var.select'
		if select_key in checkpoint and select_key in model_dict \
				and checkpoint[select_key].shape != model_dict[select_key].shape:
			checkpoint[select_key] = model_dict[select_key]
		model.load_state_dict(checkpoint)
		return model

class TemporalRelationTraining(TemporalTraining):
	"""Trainer for TemporalRelationNetwork on the CATER ``actions_order_uniq`` labels.

	The model's forward pass returns a ``(prediction, relation_scores)`` pair (see
	compute_loss). The relation scores feed an optional consistency regulariser, and the
	reasoning network's ``rule_weights`` feed an optional L1 penalty.
	"""

	def get_state_space(self, loader):
		"""Wrap ``loader`` in a ground-truth or vision state space and infer its (atomic) classes."""
		if self.gt_atomic:
			state_space = VideoStateSpaceGT(loader)
		else:
			state_space = VideoStateSpace(loader)
		state_space.infer_classes(composite=False)
		return state_space

	def get_model(self):
		"""Build a TemporalRelationNetwork, optionally freezing relations and/or the rule mapping."""
		time_dim_quant = self.config.time_dim_quant
		model = TemporalRelationNetwork(atomic_events=ACTION_CLASSES, relations=ORDERING, time_dim=MAX_FRAMES, \
			time_dim_quant=time_dim_quant, time_mult=True, conv_fill=1.0, agg_mode=self.config.agg_mode, agg_type=self.config.agg_type)

		if self.config.freeze_rela:
			self.fix_relations(model)
		if self.config.freeze_mapping:
			self.fix_mapping(model)
		return model

	def fix_relations(self, model):
		"""Freeze the temporal-quantisation and relation-network parameters (``requires_grad = False``)."""
		if model.temp_quant.quantize:
			model.temp_quant.conv1d.weight.requires_grad = False
			model.temp_quant.conv_scalar.requires_grad = False
		model.rel_network.mask_fill.requires_grad = False
		model.rel_network.network.shift.requires_grad = False
		model.rel_network.network.scale.requires_grad = False

	def fix_mapping(self, model):
		"""Replace the reasoning ``rule_weights`` with a fixed 0/1 matrix from actions_order_mapping() and freeze it."""
		gt_proj = torch.zeros_like(model.reasoning_network.rule_weights.data)
		gt_indices = actions_order_mapping()
		for idx, (class_idx, comp_event_idx) in enumerate(gt_indices):
			# set those corresponding weights, s.t. when those comp actions are detected, then we predict this class
			gt_proj[idx, [class_idx, comp_event_idx]] = 1.0
		model.reasoning_network.rule_weights.data = gt_proj
		model.reasoning_network.rule_weights.requires_grad = False

	def get_dataset(self, state_space, outputs):
		"""Wrap the state space and ``(files, actions)`` labels as timeline features."""
		return TemporalFeatures(state_space, outputs, mode='timelines')

	def get_loader(self, split):
		"""Build a shuffled DataLoader over ``split``.

		``self.samples`` caps the number of videos (``self.train_samples`` overrides it for
		``'train_subsetT'``); None means all of them.
		"""
		files, actions = load_labels(self.data_dir, split=split, folder='actions_order_uniq')
		if self.gt_atomic:
			loader = GroundTruthMovementLoader(self.data_dir, split=split)
		else:
			loader = VisionLoader(self.data_dir, split=split)

		samples = self.samples
		if split == 'train_subsetT' and self.train_samples is not None:
			samples = self.train_samples

		loader.load_all(samples=samples, folder='actions_order_uniq')
		state_space = self.get_state_space(loader)

		dataset = self.get_dataset(state_space, (files[:samples], actions[:samples]))

		feature_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
		return feature_loader

	def train(self):
		"""Build (or load) the model and train it with train / cross-validation / val splits."""
		if self.config.load_model:
			model = self.load_model()
		else:
			model = self.get_model()

		pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
		logger.info(f'num trainable parameters: {pytorch_total_params}')

		train_loader = self.get_loader('train_subsetT')
		cv_loader = self.get_loader('train_subsetV')
		val_loader = self.get_loader('val')
		model, optimizer = self.train_model(model, train_loader, val_loader, cv_loader=cv_loader)

	def post_step(self, model, batch_num):
		"""Projected-SGD step: every ``proj_interval`` batches clamp the reasoning weights to
		[ETA, 1 - ETA]; after every batch clamp ``conv_scalar`` to [ETA, 1].

		NOTE(review): conv_scalar is clamped unconditionally, while fix_relations only touches
		it when ``model.temp_quant.quantize`` is set — confirm it always exists.
		"""
		if self.proj_interval > 0 and batch_num % self.proj_interval == 0:
			reason_weights = model.reasoning_network.rule_weights.data
			reason_weights = reason_weights.clamp(0 + ETA, 1 - ETA)
			model.reasoning_network.rule_weights.data = reason_weights

		conv_scale_weight = model.temp_quant.conv_scalar.data
		conv_scale_weight = conv_scale_weight.clamp(0 + ETA, 1)
		model.temp_quant.conv_scalar.data = conv_scale_weight

	def consistency_loss(self, ts):
		"""Penalise relation predictions that disagree with their mirrored counterpart.

		``ts`` is reshaped to (batch, INPUT_DIM, INPUT_DIM, REL_DIM). For every ordered pair
		(i, j), the relation vector for (i, j) is compared with the vector for (j, i) whose last
		axis is flipped. Per the original comment, this assumes flipping swaps opposite
		relations (before <-> after).

		Returns:
			The batch mean of the per-sample Frobenius norm of the difference. This is computed
			with tensor ops, so the result stays attached to the autograd graph; previously a
			new tensor was built from the per-sample values, which detached it.
		"""
		batch_size = ts.shape[0]
		ts = ts.reshape(batch_size, INPUT_DIM, INPUT_DIM, REL_DIM)

		# reverse the directions of the relation predictions, so we compare the opposite ones before <-> after
		ts_rev = torch.flip(ts, (-1,))
		# ts_rev.transpose(1, 2)[b, i, j] == ts_rev[b, j, i]: compares pair (i, j) against (j, i)
		diff = ts - ts_rev.transpose(1, 2)
		# Frobenius norm per sample == 2-norm of the flattened per-sample difference
		per_sample = torch.norm(diff.reshape(batch_size, -1), p=2, dim=1)
		return per_sample.mean()

	def compute_loss(self, pred, label, model):
		"""BCE on the predictions plus optional L1 (``config.l1_loss``) and consistency (``config.cons_loss``) terms.

		``pred`` is the ``(prediction, relation_scores)`` pair returned by the model.
		"""
		pred, ts = pred
		if self.proj_interval > 0:
			# using BCELoss, so make sure values are [0, 1]
			pred = pred.clamp(min=ETA, max=1.0 - ETA)
		loss = self.loss_func(pred, label)

		lambda_1 = self.config.l1_loss
		lambda_2 = self.config.cons_loss

		if lambda_1 > 0:
			# mean absolute value of the reasoning weights, encouraging sparse rule mappings
			reason_weights = model.reasoning_network.rule_weights
			l1_loss_norm = torch.norm(reason_weights, p=1) / reason_weights.shape.numel()
			loss += lambda_1 * l1_loss_norm

		if lambda_2 > 0:
			cons_loss = self.consistency_loss(ts)
			loss += lambda_2 * cons_loss

		return loss

	def evaluate_rule_accuracy(self, model, k=10):
		"""Fraction of ground-truth rules that appear among the model's top-``k`` rules for their label."""
		pred_rules = model.reasoning_network.build_top_k_rules(k=k)
		gt_rules = actions_order_dataset()

		correct = 0
		for gt_rule, pred_rule in zip(gt_rules, pred_rules):
			# if the gt rule is in the top k predicted rules for that label
			if gt_rule in pred_rule:
				correct += 1
		return correct/len(gt_rules)

	def evaluate_model_voc(self, model, val_loader, print_score=True):
		"""Return the mAP on ``val_loader``; when ``print_score`` is true also log rule accuracy @1/5/10.

		``print_score`` mirrors the base-class signature (default keeps the old behaviour).
		"""
		mAP = super().evaluate_model_voc(model, val_loader, print_score=print_score)
		if print_score:
			for k in [1, 5, 10]:
				acc = self.evaluate_rule_accuracy(model, k=k)
				logger.info(f'Rule accuracy @{k}: {acc}')
		return mAP

class TemporalLSTMTraining(TemporalRelationTraining):
	"""LSTM baseline on the same CATER labels, fed '2d' features instead of timelines.

	NOTE(review): post_step is inherited from TemporalRelationTraining and touches
	``model.temp_quant`` (and ``model.reasoning_network`` when proj_interval > 0) — confirm
	TemporalLSTM exposes those attributes.
	"""

	def get_model(self):
		"""Build a TemporalLSTM from the LSTM-specific config fields."""
		model = TemporalLSTM(atomic_events=ACTION_CLASSES, relations=ORDERING, time_dim=MAX_FRAMES, hidden_size=self.config.hidden_size, num_layers=self.config.num_layers, \
			bidirectional=self.config.bidirectional, attention_dim=self.config.attention_dim)
		return model

	def get_dataset(self, state_space, outputs):
		"""Wrap the state space and labels as '2d' features."""
		return TemporalFeatures(state_space, outputs, mode='2d')

	def compute_loss(self, pred, label, model):
		"""Plain loss on the model output (a single tensor, unlike the relation network's tuple)."""
		loss = self.loss_func(pred, label)
		return loss

	def evaluate_rule_accuracy(self, model, k=10):
		"""Hits@k of the LSTM's rule ranking against the ground-truth rules.

		The tensors from ``rule_examples()`` are scored by the model. For every row of the
		output, the ``k`` highest-scoring column indices are mapped to entries of
		``actions_order_dataset(unique=False)``. Those per-row lists are then merged according
		to ``actions_order_mapping()``. A hit is counted when a ground-truth rule is in its
		merged list.
		"""
		# inference only: no autograd graph is needed to rank the rules
		with torch.no_grad():
			# based on the enumerated samples, create these samples in tensor form
			samples = rule_examples().to(self.device)
			rule_scores = model(samples)

		mapping = actions_order_mapping()
		enumerated_rules = actions_order_dataset(unique=False)

		# ascending argsort, so the last k indices of each row are its k highest scores;
		# tolist() gives plain ints for list indexing and a single device->host copy
		ranking = rule_scores.argsort(dim=1).tolist()
		pred_rules = [[enumerated_rules[rule_idx] for rule_idx in order[-k:]] for order in ranking]

		pred_rules_rev = []
		for rule_map in mapping:
			rules = []
			for index in rule_map:
				rules.extend(pred_rules[index])
			pred_rules_rev.append(rules)

		gt_rules = actions_order_dataset(unique=True)
		# if the gt rule is in the top k predicted rules for that label
		correct = sum(1 for gt_rule, pred_rule in zip(gt_rules, pred_rules_rev) if gt_rule in pred_rule)
		return correct/len(gt_rules)

class TemporalMAPTraining(TemporalRelationTraining):
	"""MAP-style baseline: counts which relations the frozen relation network predicts for each
	active label on the training split, and uses the most frequent ones as that label's rules.
	"""

	def get_model(self):
		"""Build a TemporalMAP model with its relation parameters frozen."""
		# standard network, but without the reasoning layer, use the rel_network and aggregate the results
		time_dim_quant = self.config.time_dim_quant
		model = TemporalMAP(atomic_events=ACTION_CLASSES, relations=ORDERING, time_dim=MAX_FRAMES, \
			time_dim_quant=time_dim_quant, time_mult=True, conv_fill=1.0, agg_mode=self.config.agg_mode, agg_type=self.config.agg_type)

		self.fix_relations(model)
		return model

	def train(self, load=False, samples=None, gt_atomic=False):
		"""Build (or load) the model and aggregate relation counts on 'train_subsetT'.

		``load``, ``samples`` and ``gt_atomic`` are unused; the config controls these.
		"""
		if self.config.load_model:
			model = self.load_model()
		else:
			model = self.get_model()

		pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
		logger.info(f'num trainable parameters: {pytorch_total_params}')

		train_loader = self.get_loader('train_subsetT')
		self.train_model(model, train_loader)

	def train_model(self, model, train_loader):
		"""Count, per active ground-truth label, the relations the model predicts (rounded
		output != 0), then log the resulting rule accuracy.
		"""
		model = model.to(self.device)
		model.eval()

		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		fact_mapping = actions_order_dataset(unique=False)
		# store the relations that occur per active label in each example, so at the end we compute the mode
		# relation as the predicted rule
		relation_counts = defaultdict(Counter)

		logger.info('aggregating MAP relation results')
		# pure inference over the training split: no autograd graph needed
		with torch.no_grad():
			for sample, batch_labels in iter_wrapper(train_loader):
				sample, batch_labels = sample.to(self.device), batch_labels.to(self.device)
				round_pred = model(sample).round()

				for gt_label, pred_rela in zip(batch_labels, round_pred):
					# indices of non-zero entries, one host transfer instead of one .item() per element
					active_relations = [fact_mapping[rel_idx] for rel_idx in torch.nonzero(pred_rela).flatten().tolist()]
					# for every active label, add those predicted relations
					for label_idx in torch.nonzero(gt_label).flatten().tolist():
						relation_counts[label_idx].update(active_relations)

		logger.info('evaluating MAP rules')
		self.evaluate_model_voc(relation_counts)

	def generate_MAP_rules(self, relation_counts, k=10, num_labels=None):
		"""Return, per label id, its ``k`` most frequent relations plus their reverses.

		Args:
			relation_counts: mapping label id -> Counter of relations.
			k: number of most common relations taken per label.
			num_labels: number of label slots to emit. Defaults to the largest label id in
				``relation_counts`` plus one, so missing ids get an empty list and later ids keep
				their position. ``relation_counts`` is not mutated.
		"""
		if num_labels is None:
			num_labels = max(relation_counts, default=-1) + 1

		pred_rules = []
		for label_id in range(num_labels):
			# .get avoids inserting missing keys into a defaultdict
			counter = relation_counts.get(label_id)
			if not counter:
				pred_rules.append([])
				continue

			relations = [relation for relation, _count in counter.most_common(k)]

			# add reverse relations; the loop also visits appended reverses, which terminates as
			# long as reverse() is an involution (presumably before <-> after — confirm)
			for relation in relations:
				cons_fact = reverse(relation)
				if cons_fact not in relations:
					relations.append(cons_fact)

			pred_rules.append(relations)
		return pred_rules

	def evaluate_rule_accuracy(self, relation_counts, k=10):
		"""Fraction of ground-truth rules found among the MAP rules of their label, using the top ``k`` relations."""
		gt_rules = actions_order_dataset()
		# honour k (it was previously hard-coded to 10, making every HITS@k identical)
		pred_rules = self.generate_MAP_rules(relation_counts, k=k, num_labels=len(gt_rules))
		# if the gt rule is in the top k predicted rules for that label
		correct = sum(1 for gt_rule, pred_rule in zip(gt_rules, pred_rules) if gt_rule in pred_rule)
		return correct/len(gt_rules)

	def evaluate_model_voc(self, relation_counts):
		"""Log MAP rule accuracy @1/5/10."""
		for k in [1, 5, 10]:
			acc = self.evaluate_rule_accuracy(relation_counts, k=k)
			logger.info(f'Rule accuracy @{k}: {acc}')

class ModeBaseline(TemporalRelationTraining):
	"""Label-frequency baseline: predicts the same label set for every validation video.

	The set is the ``floor(mean number of active labels)`` most frequent labels in the
	training split.
	"""

	def train(self, load=False, samples=None, gt_atomic=False):
		"""Build the train / val loaders and evaluate the mode baseline.

		Args:
			load: unused, kept for signature compatibility.
			samples: if not None, overrides the configured per-split sample cap.
			gt_atomic: if true, forces the ground-truth atomic-event loader.
		"""
		# get_loader only takes the split and reads these settings from the instance
		# (passing them as keyword arguments raised TypeError)
		if samples is not None:
			self.samples = samples
		if gt_atomic:
			self.gt_atomic = True
		train_loader = self.get_loader('train_subsetT')
		val_loader = self.get_loader('val')
		self.train_model(train_loader, val_loader)

	def train_model(self, train_loader, val_loader):
		"""Count label frequencies on ``train_loader``, then evaluate the fixed mode prediction.

		Raises:
			ValueError: if ``train_loader`` yields no samples.
		"""
		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		composite_counts = Counter()
		active_values = []
		num_labels = None

		logger.info('aggregating mode relation results')
		for _sample, batch_labels in iter_wrapper(train_loader):
			# label vectors are the last axis of each batch
			num_labels = batch_labels.shape[-1]
			for label in batch_labels:
				active_values.append(label.sum().item())
				active_labels = torch.where(label > 0)[0].tolist()
				composite_counts.update(active_labels)

		if not active_values:
			raise ValueError('ModeBaseline needs at least one training sample')

		mean_active = floor(sum(active_values)/len(active_values))
		logger.info(f'mean active labels {mean_active}')
		logger.info('label counts')
		logger.info(composite_counts)
		# most_common(0) is empty: a mean below one label gives an all-zero prediction
		# instead of failing to unpack
		top_labels = [label_id for label_id, _count in composite_counts.most_common(mean_active)]

		pred_label = np.zeros(num_labels)
		if top_labels:
			pred_label[top_labels] = 1.0

		self.evaluate_model_voc(val_loader, pred_label)

	def evaluate_model_voc(self, val_loader, pred_label):
		"""Return and log the mAP of predicting ``pred_label`` for every sample of ``val_loader``."""
		all_labels = []
		for sample, label in val_loader:
			all_labels.append(label.cpu().numpy())

		all_labels = np.vstack(all_labels)
		# we are using the same predicted label from the mode
		all_predictions = np.tile(pred_label, (all_labels.shape[0], 1))
		ap = compute_multiple_aps(all_labels, all_predictions)
		mAP = compute_mAP(ap)
		logger.info(f'validation mAP {mAP}')
		return mAP

class GeneratedTraining(TemporalRelationTraining):
	"""Trainer for generated (synthetic) timelines whose labels are defined by ``self.rules``.

	``train`` has to run first: it sets ``self.sim_data_path`` and ``self.rules``, which
	``get_model`` and ``get_loader`` rely on.
	"""

	def get_model(self):
		"""Build a TemporalRelationNetwork with one composite event per rule, optionally freezing parts."""
		cfg = self.config
		model = TemporalRelationNetwork(
			num_comp_events=len(self.rules),
			atomic_events=ACTION_CLASSES,
			relations=ORDERING,
			time_dim=MAX_FRAMES,
			time_dim_quant=cfg.time_dim_quant,
			time_mult=True,
			conv_fill=1.0,
			agg_mode=cfg.agg_mode,
			agg_type=cfg.agg_type,
		)

		for should_freeze, freeze in ((cfg.freeze_rela, self.fix_relations), (cfg.freeze_mapping, self.fix_mapping)):
			if should_freeze:
				freeze(model)
		return model

	def get_dataset(self, timelines):
		"""Wrap loaded timelines as GeneratedFeatures with one output per rule."""
		n_outputs = len(self.rules)
		return GeneratedFeatures(timelines, output_dim=n_outputs)

	def get_loader(self, split):
		"""Load the JSON timelines of ``split`` from ``self.sim_data_path`` into a shuffled DataLoader."""
		with open(get_path(split, self.sim_data_path)) as handle:
			timelines = json.load(handle)

		return DataLoader(self.get_dataset(timelines), batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

	def generate_rules(self, model, train_loader, val_loader, cv_loader=None):
		"""Post-training rule-generation hook; intentionally empty here."""
		return None

	def train(self):
		"""Set up the generated data and rules, then train with train / val / test splits."""
		cfg = self.config
		data_folder = folder_format(cfg.gen_len, cfg.gen_events, cfg.gen_samples, cfg.gen_rules_beam)
		self.sim_data_path = os.path.join(cfg.gen_path, data_folder)
		self.rules = generate_rules(n_predicates=cfg.gen_len)

		model = self.load_model() if cfg.load_model else self.get_model()

		n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
		logger.info(f'num trainable parameters: {n_trainable}')

		train_loader, cv_loader, val_loader = (self.get_loader(split) for split in ('train', 'val', 'test'))
		model, _optimizer = self.train_model(model, train_loader, val_loader, cv_loader=cv_loader)
		self.generate_rules(model, train_loader, val_loader, cv_loader=cv_loader)

	def evaluate_rule_accuracy(self, model):
		"""Fraction of rules that exactly equal the dynamically built top rule for their label."""
		predicted = model.reasoning_network.build_top_k_rules_dynamic(k=10, append_thresh=0.01)
		hits = sum(1 for expected, found in zip(self.rules, predicted) if expected == tuple(found))
		return hits/len(self.rules)

	def test_fixed_rule_accuracy(self, model):
		"""Fraction of rules whose first element is in the top-1 rule built for their label."""
		predicted = model.reasoning_network.build_top_k_rules(k=1)
		hits = sum(1 for expected, found in zip(self.rules, predicted) if expected[0] in found)
		return hits/len(self.rules)

	def evaluate_model_voc(self, model, val_loader):
		"""Return the base-class mAP and log rule accuracy (skips the parent's HITS@k logging)."""
		score = super(TemporalRelationTraining, self).evaluate_model_voc(model, val_loader)
		logger.info(f'Rule accuracy: {self.evaluate_rule_accuracy(model)}')
		return score

class GeneratedTrainingVariable(GeneratedTraining):
	"""Two-stage trainer for variable-length rules on generated data.

	Stage 1 (train_model) trains the relation network and the reasoning attention
	(``rule_weights``) while accumulating them into ``self.pre_proj_weight``. Stage 2
	(generate_rules / tune_var_model) freezes those and trains only
	``reasoning_network_var.select`` with ``model.rule_search`` enabled.
	"""

	def get_model(self):
		"""Build a TemporalRelationNetworkVariable sized by the rules and rule-search config."""
		time_dim_quant = self.config.time_dim_quant
		model = TemporalRelationNetworkVariable(num_comp_events=len(self.rules), atomic_events=ACTION_CLASSES, relations=ORDERING, time_dim=MAX_FRAMES, \
			time_dim_quant=time_dim_quant, time_mult=True, conv_fill=1.0, agg_mode=self.config.agg_mode, agg_type=self.config.agg_type, \
				max_rule_len=self.config.gen_len, attn_cand_beam=self.config.gen_len_beam, variable_len=self.variable_len, max_rules_beam=self.config.gen_rules_beam)

		if self.config.freeze_rela:
			self.fix_relations(model)
		if self.config.freeze_mapping:
			self.fix_mapping(model)
		return model

	def save_model(self, model):
		"""Save the projection weights together with the state dict."""
		save_dict = {'proj_weight': self.get_proj_weights(model), 'model_state_dict': model.state_dict()}
		torch.save(save_dict, self.save_path)

	def load_model(self):
		"""Load a checkpoint written by save_model; restores ``self.pre_proj_weight`` as well."""
		model = self.get_model()
		model_dict = model.state_dict()

		logger.info(f'Loading model from: {self.save_path}')
		load_dict = torch.load(self.save_path, map_location=torch.device(self.device))

		self.pre_proj_weight = load_dict['proj_weight']
		checkpoint = load_dict['model_state_dict']
		# select parameters are tied to the beam width; when only those differ (e.g. testing other widths) keep the current ones
		select_key = 'reasoning_network_var.select'
		if select_key in checkpoint and select_key in model_dict \
				and checkpoint[select_key].shape != model_dict[select_key].shape:
			checkpoint[select_key] = model_dict[select_key]
		model.load_state_dict(checkpoint)
		return model

	def train(self):
		"""Set up data / rules, run stage 1 (unless skipped), then the rule-generation stage."""
		self.variable_len = not self.config.var_len_fixed

		folder = folder_format(self.config.gen_len, self.config.gen_events, self.config.gen_samples, self.config.gen_rules_beam)
		self.sim_data_path = os.path.join(self.config.gen_path, folder)
		self.rules = generate_rules(n_predicates=self.config.gen_len, variable_len=self.variable_len, max_rules_beam=self.config.gen_rules_beam)

		if self.config.load_model:
			model = self.load_model()
		else:
			model = self.get_model()
			# keep track of projection matrix in the first training phase.
			self.pre_proj_weight = model.reasoning_network.rule_weights.detach().clone().to(self.device)

		pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
		logger.info(f'num trainable parameters: {pytorch_total_params}')

		train_loader = self.get_loader('train')
		cv_loader = self.get_loader('val')
		val_loader = self.get_loader('test')

		model.reasoning_network_var.select.requires_grad = False  # do not perform gradient updates for this in the first stage

		if not self.config.var_skip_proj_training:
			model, optimizer = self.train_model(model, train_loader, val_loader, cv_loader=cv_loader, save=not model.rule_search)
		elif not self.config.load_model:
			# check the config flag (self.load_model is a bound method and always truthy)
			logger.warning(f'WARNING: running rule generation using untrained attentions')
		self.generate_rules(model, train_loader, val_loader, cv_loader=cv_loader)

	def get_count_proj_weights(self, counts):
		"""Build a (num rules, num predicates) matrix from per-label predicate counts (e.g. get_counts output).

		Previously defined as ``get_proj_weights(self, counts)``, which was silently shadowed
		by the model-based get_proj_weights below and therefore unreachable.
		NOTE(review): uses ``actions_order_dataset(n=2, ...)`` whereas get_counts uses the default n — confirm they match.
		"""
		fact_mapping = actions_order_dataset(n=2, unique=False)
		comp_events = len(self.rules)
		num_predicates = len(fact_mapping)
		weights = torch.zeros(comp_events, num_predicates)

		for label in range(comp_events):
			for predicate, count in counts[label].items():
				predicate_idx = fact_mapping.index(predicate)
				weights[label, predicate_idx] = count

		return weights

	def get_counts(self, model, train_loader):
		"""Count, MAP style, the relations the trained model predicts for each active label.

		``model.relations`` is set for the duration of counting and always reset afterwards.
		"""
		count_device = self.device
		model = model.to(count_device)
		model.eval()
		model.relations = True

		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		fact_mapping = actions_order_dataset(unique=False)
		# store the relations that occur per active label in each example, so at the end we compute the mode
		# relation as the predicted rule
		relation_counts = defaultdict(Counter)

		logger.info('aggregating MAP relation results')
		try:
			with torch.no_grad():
				for sample, batch_labels in iter_wrapper(train_loader):
					sample, batch_labels = sample.to(count_device), batch_labels.to(count_device)
					round_pred = model(sample).round()

					for gt_label, pred_rela in zip(batch_labels, round_pred):
						# relations whose rounded prediction is non-zero
						active_relations = [fact_mapping[rel_idx] for rel_idx in torch.nonzero(pred_rela).flatten().tolist()]
						# for every active label, add those predicted relations
						for label_idx in torch.nonzero(gt_label).flatten().tolist():
							relation_counts[label_idx].update(active_relations)
		finally:
			model.relations = False
		model = model.to(self.device)
		return relation_counts

	def generate_rules(self, model, train_loader, val_loader, cv_loader=None):
		"""Stage 2: freeze relations and attention, seed the combinatorial search, tune ``select``."""
		logger.info('running rule generation training')
		# compute_loss clamps predictions to [ETA, 1 - ETA] while rule_search is on, so plain BCE applies
		self.loss_func = nn.BCELoss()
		self.epochs = self.config.gen_epochs

		self.fix_relations(model)  # dont update the relation parameters learned at this point
		attention_weights = model.reasoning_network.rule_weights
		attention_weights.requires_grad = False  # fix the attentions per label

		# count-based alternative: self.get_count_proj_weights(self.get_counts(model, train_loader))
		weights = self.get_proj_weights(model)

		model.rule_search = True
		model.reasoning_network_var.select.requires_grad = True
		model.reasoning_network_var.gen_combs(weights)  # use the highest attention weights to guide the combinatorial rule search

		# var len rule generation is memory expensive, so reduce batch size if needed
		if self.batch_size != self.config.gen_len_batch_size:
			self.batch_size = self.config.gen_len_batch_size
			train_loader = self.get_loader('train')
			cv_loader = self.get_loader('val')
			val_loader = self.get_loader('test')

		model, optimizer = self.tune_var_model(model, train_loader, val_loader, cv_loader=cv_loader, save=False)

	def compute_loss(self, pred, label, model):
		"""BCE loss plus stage-specific regularisers.

		Stage 1 adds the optional L1 / consistency terms; during rule search only the mean
		absolute value of ``reasoning_network_var.select`` is added (weight 1.0).
		"""
		pred, ts = pred
		if self.proj_interval > 0 or model.rule_search:
			# using BCELoss, so make sure values are [0, 1]
			pred = pred.clamp(min=ETA, max=1.0 - ETA)
		loss = self.loss_func(pred, label)

		lambda_1 = self.config.l1_loss
		lambda_2 = self.config.cons_loss

		if lambda_1 > 0 and not model.rule_search:
			reason_weights = model.reasoning_network.rule_weights
			l1_loss_norm = torch.norm(reason_weights, p=1) / reason_weights.shape.numel()
			loss += lambda_1 * l1_loss_norm

		if lambda_2 > 0 and not model.rule_search:
			cons_loss = self.consistency_loss(ts)
			loss += lambda_2 * cons_loss

		if model.rule_search:
			comb_select = model.reasoning_network_var.select
			l1_comb_norm = torch.norm(comb_select, p=1) / comb_select.shape.numel()
			loss += 1.0 * l1_comb_norm

		return loss

	def add_proj_weight(self, sample):
		"""Accumulate a detached copy of ``sample`` into ``self.pre_proj_weight``."""
		self.pre_proj_weight += sample.detach().clone()

	def get_proj_weights(self, model):
		"""Accumulated weights when projecting (proj_interval > 0), else the model's current ``rule_weights``."""
		if self.proj_interval > 0:
			proj_weights = self.pre_proj_weight
		else:
			proj_weights = model.reasoning_network.rule_weights
		return proj_weights

	def post_step(self, model, batch_num):
		"""Accumulate the rule weights, then apply the same clamping as the parent trainer."""
		self.add_proj_weight(model.reasoning_network.rule_weights)

		if self.proj_interval > 0 and batch_num % self.proj_interval == 0:
			reason_weights = model.reasoning_network.rule_weights.data
			reason_weights = reason_weights.clamp(0 + ETA, 1 - ETA)
			model.reasoning_network.rule_weights.data = reason_weights

		conv_scale_weight = model.temp_quant.conv_scalar.data
		conv_scale_weight = conv_scale_weight.clamp(0 + ETA, 1)
		model.temp_quant.conv_scalar.data = conv_scale_weight

	def evaluate_rule_accuracy(self, model, k=5):
		"""Return (overall hits@k, {'rule len L': hits@k for rules of length L}).

		A ground-truth rule is a hit when check_rule_match accepts any of its label's top rules.
		"""
		if model.rule_search:
			pred_rules = model.reasoning_network_var.build_var_len_rules(k=k)
		else:
			pred_rules = model.reasoning_network.build_top_k_rules_dynamic(k=k, append_thresh=0.05)

		# TODO: build consistent rules either enumerate or dynamically, check lens, then check atomic events
		gt_rules = self.rules
		hits_by_len = defaultdict(list)
		for gt_rule, top_rules in zip(gt_rules, pred_rules):
			hit = any(check_rule_match(gt_rule, pred_rule) for pred_rule in top_rules)
			hits_by_len[len(gt_rule)].append(hit)
		correct = sum(sum(hits) for hits in hits_by_len.values())
		accuracies = {f'rule len {rule_len}': sum(hits)/len(hits) for rule_len, hits in hits_by_len.items()}
		return correct/len(gt_rules), accuracies

	def check_rule_recall(self, model):
		"""Rank facts by |projection weight| (descending) and let the variable-length network compute its stats."""
		attention_weights = self.get_proj_weights(model)
		sorted_weights = torch.flip(torch.argsort(torch.abs(attention_weights), dim=1), dims=[1])
		fact_indices = model.reasoning_network_var.get_fact_indices(sorted_weights)
		model.reasoning_network_var.compute_stats(fact_indices, sorted_weights)

	def evaluate_model_voc(self, model, val_loader):
		"""Return the base-class mAP; log HITS@1/5/10 during rule search, else the recall statistics."""
		mAP = super(TemporalRelationTraining, self).evaluate_model_voc(model, val_loader)
		if model.rule_search:
			for k in [1, 5, 10]:
				total_acc, accuracies = self.evaluate_rule_accuracy(model, k=k)
				logger.info(f'HITS@{k} overall: {total_acc} by len: {accuracies}')
		else:
			self.check_rule_recall(model)
		return mAP

	def tune_var_model(self, model, train_loader, val_loader, cv_loader=None, save=False):
		"""Stage-2 training loop with model selection after every optimizer step.

		The selection split is ``cv_loader``, or ``val_loader`` when ``cv_loader`` is None. The
		best-mAP snapshot is kept (and written to disk when ``save`` is true). Its rule metrics
		are logged on the selection split after each epoch and on ``val_loader`` at the end.

		Returns:
			(model, optimizer): the model after the last step (not the snapshot) and its optimizer.
		"""
		model = model.to(self.device)

		# same filtering as the stage-1 optimizer: frozen layers get no updates
		optimizer = self.get_optimizer(model)
		selection_loader = cv_loader if cv_loader is not None else val_loader

		logger.info(f'Training on device: {self.device}')
		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		save_model = deepcopy(model)
		best_mAP = 0.0
		batch_num = 1

		for epoch in range(1, self.epochs + 1):
			losses = []
			for sample, label in iter_wrapper(train_loader):
				# the per-step evaluation below switches to eval mode, so re-enable training each step
				model.train()
				sample, label = sample.to(self.device), label.to(self.device)
				optimizer.zero_grad()
				pred = model(sample)
				loss = self.compute_loss(pred, label, model)
				losses.append(loss.item())
				loss.backward()
				optimizer.step()
				self.post_step(model, batch_num)
				# select after every step; bypass this class's evaluate_model_voc to skip its rule-metric logging
				mAP = super(TemporalRelationTraining, self).evaluate_model_voc(model, selection_loader, print_score=False)
				if mAP >= best_mAP:
					# TODO: keep the snapshot on the CPU instead of taking up GPU memory (or only write to disk)
					save_model = deepcopy(model)
					best_mAP = mAP
					if save:
						self.save_model(model)
						logger.info('saving model to disk')
				batch_num += 1
			logger.info(f'Epoch {epoch} mean batch loss = {np.mean(losses)}')
			self.evaluate_model_voc(save_model, selection_loader)
		logger.info('evaluating test data')
		self.evaluate_model_voc(save_model, val_loader)
		return model, optimizer

class GeneratedTrainingMAP(GeneratedTraining):
	"""`TemporalMAP` baseline for rule induction on the generated dataset.

	The network (no learned reasoning layer) is run once over the training data. For every
	positive label of an example, the relations the network predicts for that example are
	counted. The most frequent relations per label are then turned into candidate rules and
	scored against the ground-truth rules (`self.rules`) with HITS@k.
	"""

	def get_model(self):
		"""Build a `TemporalMAP` network and pass it through `self.fix_relations`."""
		# standard network, but without the reasoning layer: use the relation network and aggregate the results
		time_dim_quant = self.config.time_dim_quant
		model = TemporalMAP(atomic_events=ACTION_CLASSES, relations=ORDERING, time_dim=MAX_FRAMES,
			time_dim_quant=time_dim_quant, time_mult=True, conv_fill=1.0,
			agg_mode=self.config.agg_mode, agg_type=self.config.agg_type)

		self.fix_relations(model)
		return model

	def get_dataset(self, timelines):
		"""Wrap `timelines` in a `GeneratedFeatures` dataset with one output per ground-truth rule."""
		return GeneratedFeatures(timelines, output_dim=len(self.rules), mode='agg_before')

	def train(self, load=False, samples=None, gt_atomic=False):
		"""Generate the ground-truth rules, then aggregate and evaluate MAP statistics on the train split.

		`load`, `samples` and `gt_atomic` are unused. Whether a model is loaded is controlled
		by `config.load_model`.
		"""
		folder = folder_format(self.config.gen_len, self.config.gen_events, self.config.gen_samples, self.config.gen_rules_beam)
		self.sim_data_path = os.path.join(self.config.gen_path, folder)

		self.variable_len = not self.config.var_len_fixed
		self.rules = generate_rules(n_predicates=self.config.gen_len, variable_len=self.variable_len, max_rules_beam=self.config.gen_rules_beam)

		if self.config.load_model:
			model = self.load_model()
		else:
			model = self.get_model()

		pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
		logger.info(f'num trainable parameters: {pytorch_total_params}')

		train_loader = self.get_loader('train')
		self.train_model(model, train_loader)

	def train_model(self, model, train_loader):
		"""Count predicted relations per active label over `train_loader`, then log HITS@k of the induced rules.

		Args:
			model: `TemporalMAP` network; only used for inference.
			train_loader: iterable of `(sample, label)` batches.
		"""
		model = model.to(self.device)
		model.eval()

		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		fact_mapping = actions_order_dataset(unique=False)
		# store the relations that occur per active label in each example, so at the end we compute the mode
		# relation as the predicted rule
		relation_counts = defaultdict(Counter)

		logger.info('aggregating MAP relation results')
		# inference only: no autograd graph. Results are moved to the cpu once per batch instead of
		# syncing with the device for every element via `.item()`.
		with torch.no_grad():
			for sample, label in iter_wrapper(train_loader):
				round_pred = model(sample.to(self.device)).round().cpu()

				for gt_label, pred_rela in zip(label.cpu(), round_pred):
					# relations predicted for this example (non-zero after rounding)
					active_relations = [fact_mapping[rel_idx] for rel_idx in pred_rela.nonzero(as_tuple=True)[0].tolist()]
					# for every active label, add those predicted relations
					for label_idx in gt_label.nonzero(as_tuple=True)[0].tolist():
						relation_counts[label_idx].update(active_relations)

		logger.info('evaluating MAP rules')
		logger.info('evaluating test data')
		self.evaluate_model_voc(relation_counts)

	def generate_MAP_rules(self, relation_counts, k=10, append_thresh=0.1):
		"""Turn aggregated relation counts into up to `k` candidate rules per label.

		Args:
			relation_counts: mapping label index -> Counter of predicted relations.
			k: maximum number of candidate rules kept per label.
			append_thresh: a relation is appended to a rule only while the drop in normalised
				count from the previously appended relation stays below this value.

		Returns:
			List with one entry per ground-truth rule (``len(self.rules)`` entries, aligned
			with `self.rules`). Each entry is a list of rule tuples, empty when the label
			has no recorded relations.
		"""
		pred_rules = []
		# Iterate over every label, not just the ones seen during aggregation. A label with no
		# positive training example has no entry, and `range(len(relation_counts))` would then
		# skip the highest label ids and misalign predictions with `self.rules`.
		# `.get` also avoids inserting empty Counters into the caller's defaultdict.
		for label_id in range(len(self.rules)):
			counter = relation_counts.get(label_id)
			if not counter:
				pred_rules.append([])
				continue

			relations, counts = zip(*counter.most_common(self.config.gen_len_beam))
			scores = np.array(counts)
			scores = scores / scores.sum()

			# keep each relation once, skipping any whose `reverse(...)` form is already kept
			unique_rel = []
			unique_scores = []
			for relation, score in zip(relations, scores):
				if relation not in unique_rel and reverse(relation) not in unique_rel:
					unique_rel.append(relation)
					unique_scores.append(score)

			rule_len = self.config.gen_len
			if rule_len > 1:
				# Combinatorial search over the top relations. combinations() keeps the
				# count-descending order, so each rule starts with its highest-scored relation.
				scored = list(zip(unique_rel, unique_scores))
				num_combs = min(len(scored), rule_len)
				rules = []
				for comb in combinations(scored, num_combs):
					rule = []
					prev_score = comb[0][1]
					for predicate, score in comb:
						# stop extending the rule once the score drops too sharply
						if prev_score - score < append_thresh:
							rule.append(predicate)
							prev_score = score
						else:
							break
					rules.append(tuple(rule))
			else:
				rules = [(relation,) for relation in unique_rel]

			pred_rules.append(rules[:k])
		return pred_rules

	def evaluate_rule_accuracy(self, relation_counts, k=10):
		"""Compute HITS@k of the MAP-generated rules against `self.rules`.

		Returns:
			Tuple of (fraction of ground-truth rules matched by one of their label's top-k
			predictions, dict ``'rule len N'`` -> that fraction restricted to rules of length N).
		"""
		all_pred_rules = self.generate_MAP_rules(relation_counts, k=k, append_thresh=self.config.append_thresh)
		gt_rules = self.rules
		hits_by_len = defaultdict(list)
		correct = 0
		for gt_rule, pred_rules in zip(gt_rules, all_pred_rules):
			# hit if the gt rule is in the top k predicted rules for that label
			top_match = any(check_rule_match(gt_rule, pred_rule) for pred_rule in pred_rules)
			correct += int(top_match)
			hits_by_len[len(gt_rule)].append(top_match)
		accuracies = {f'rule len {rule_len}': sum(hits) / len(hits) for rule_len, hits in hits_by_len.items()}
		total_acc = correct / len(gt_rules) if gt_rules else 0.0
		return total_acc, accuracies

	def evaluate_model_voc(self, relation_counts):
		"""Log HITS@1/5/10 for the rules derived from the aggregated `relation_counts`."""
		for k in [1, 5, 10]:
			total_acc, accuracies = self.evaluate_rule_accuracy(relation_counts, k=k)
			logger.info(f'HITS@{k} overall: {total_acc} by len: {accuracies}')

def freeze_params(model):
	"""Turn off gradient tracking for every parameter returned by ``model.parameters()``."""
	for param in model.parameters():
		param.requires_grad_(False)

# run MAP to generate the weights, then run var TLP model for rule selection
class GeneratedTrainingVariableMAP(GeneratedTraining):

	def get_model(self):
		"""Build the variable-length TLP network used for rule selection.

		Every parameter is frozen with `freeze_params`. Gradients are then re-enabled only on
		`reasoning_network_var.select`.
		"""
		cfg = self.config
		model = TemporalRelationNetworkVariable(
			num_comp_events=len(self.rules),
			atomic_events=ACTION_CLASSES,
			relations=ORDERING,
			time_dim=MAX_FRAMES,
			time_dim_quant=cfg.time_dim_quant,
			time_mult=True,
			conv_fill=1.0,
			agg_mode=cfg.agg_mode,
			agg_type=cfg.agg_type,
			max_rule_len=cfg.gen_len,
			attn_cand_beam=cfg.gen_len_beam,
			variable_len=self.variable_len,
			max_rules_beam=cfg.gen_rules_beam,
		)

		freeze_params(model)
		# the rule-selection tensor is the one thing tuned during rule search
		model.reasoning_network_var.select.requires_grad = True
		return model

	# compute MAP statistics
	def train(self, load=False, samples=None, gt_atomic=False):
		"""Two-stage run: aggregate relation counts with a frozen `TemporalMAP`, then train rule selection on them.

		`load`, `samples` and `gt_atomic` are accepted for interface compatibility but unused.
		"""
		cfg = self.config
		data_folder = folder_format(cfg.gen_len, cfg.gen_events, cfg.gen_samples, cfg.gen_rules_beam)
		self.sim_data_path = os.path.join(cfg.gen_path, data_folder)

		self.variable_len = not cfg.var_len_fixed
		self.rules = generate_rules(n_predicates=cfg.gen_len, variable_len=self.variable_len, max_rules_beam=cfg.gen_rules_beam)

		# stage 1: a fully frozen MAP network that only gathers relation statistics
		map_model = TemporalMAP(
			atomic_events=ACTION_CLASSES,
			relations=ORDERING,
			time_dim=MAX_FRAMES,
			time_dim_quant=cfg.time_dim_quant,
			time_mult=True,
			conv_fill=1.0,
			agg_mode=cfg.agg_mode,
			agg_type=cfg.agg_type,
		)
		freeze_params(map_model)

		n_trainable = sum(p.numel() for p in map_model.parameters() if p.requires_grad)
		logger.info(f'num trainable parameters: {n_trainable}')

		relation_counts = self.train_model(map_model, self.get_loader('train'))
		# stage 2: learn which relation combinations form each rule
		self.generate_rules(relation_counts)

	def train_model(self, model, train_loader, atomic_events=ACTION_CLASSES):
		"""Run the MAP network over `train_loader` and count predicted relations per active label.

		Args:
			model: `TemporalMAP` network; only used for inference.
			train_loader: iterable of `(sample, label)` batches.
			atomic_events: atomic event vocabulary given to `actions_order_dataset`. It builds
				the list that maps a prediction index back to its relation.

		Returns:
			defaultdict mapping label index -> Counter of the relations predicted (non-zero
			after rounding) in the examples where that label is active.
		"""
		model = model.to(self.device)
		model.eval()

		iter_wrapper = (lambda x: tqdm(x, total=len(train_loader))) if self.tqdm else (lambda x: x)

		fact_mapping = actions_order_dataset(atomic_events=atomic_events, unique=False)
		# store the relations that occur per active label in each example, so at the end we compute the mode
		# relation as the predicted rule
		relation_counts = defaultdict(Counter)

		logger.info('aggregating MAP relation results')
		# inference only: no autograd graph. Results are moved to the cpu once per batch instead of
		# syncing with the device for every element via `.item()`.
		with torch.no_grad():
			for sample, label in iter_wrapper(train_loader):
				round_pred = model(sample.to(self.device)).round().cpu()

				for gt_label, pred_rela in zip(label.cpu(), round_pred):
					# relations predicted for this example (non-zero after rounding)
					active_relations = [fact_mapping[rel_idx] for rel_idx in pred_rela.nonzero(as_tuple=True)[0].tolist()]
					# for every active label, add those predicted relations
					for label_idx in gt_label.nonzero(as_tuple=True)[0].tolist():
						relation_counts[label_idx].update(active_relations)

		logger.info('evaluating MAP rules')
		logger.info('evaluating test data')
		return relation_counts

	# given the MAP counts, train the TLP attention to select the rules
	def generate_rules(self, counts):
		"""Train the TLP rule-selection tensor, guided by the MAP relation counts.

		Side effects: sets `self.loss_func`, `self.epochs` (from `config.gen_epochs`) and
		`self.batch_size` (from `config.gen_len_batch_size`).

		Args:
			counts: mapping label index -> Counter of predicted relations, as returned by `train_model`.
		"""
		logger.info('running rule generation training')
		model = self.get_model()
		model.rule_search = True
		self.loss_func = nn.BCELoss()
		self.epochs = self.config.gen_epochs

		# keep the per-label attention weights fixed; only the rule selection is learned
		model.reasoning_network.rule_weights.requires_grad = False

		var_net = model.reasoning_network_var
		var_net.select.requires_grad = True
		# the (labels, predicates) count matrix guides the combinatorial rule search
		var_net.gen_combs(self.get_proj_weights(counts))

		# variable-length rule generation is memory hungry, so it uses its own batch size
		self.batch_size = self.config.gen_len_batch_size
		train_loader = self.get_loader('train')
		cv_loader = self.get_loader('val')
		test_loader = self.get_loader('test')

		self.tune_var_model(model, train_loader, test_loader, cv_loader=cv_loader, save=False)

	def compute_loss(self, pred, label, model):
		"""Return `self.loss_func` on the predictions plus the active regularisers.

		`pred` is a `(predictions, ts)` pair.
		- During rule search, an L1 penalty (mean absolute value, weight 1.0) on
		  `reasoning_network_var.select` is added.
		- Otherwise, an L1 penalty on `reasoning_network.rule_weights` (weight `config.l1_loss`)
		  and `self.consistency_loss(ts)` (weight `config.cons_loss`) are added when their
		  weights are positive.
		"""
		scores, ts = pred
		if self.proj_interval > 0 or model.rule_search:
			# using BCELoss, so keep values strictly inside (0, 1)
			scores = scores.clamp(min=ETA, max=1.0 - ETA)
		loss = self.loss_func(scores, label)

		l1_weight = self.config.l1_loss
		cons_weight = self.config.cons_loss

		if not model.rule_search:
			if l1_weight > 0:
				rule_weights = model.reasoning_network.rule_weights
				loss += l1_weight * (torch.norm(rule_weights, p=1) / rule_weights.shape.numel())
			if cons_weight > 0:
				loss += cons_weight * self.consistency_loss(ts)
		else:
			select = model.reasoning_network_var.select
			loss += 1.0 * (torch.norm(select, p=1) / select.shape.numel())

		return loss

	def add_proj_weight(self, sample):
		"""Add a detached copy of `sample` to the running `self.pre_proj_weight` total."""
		snapshot = sample.detach().clone()
		self.pre_proj_weight += snapshot
	
	def get_proj_weights(self, counts):
		"""Build a dense (labels, predicates) matrix from per-label relation counts.

		Args:
			counts: mapping label index -> Counter of predicates (as produced by `train_model`).

		Returns:
			Float tensor of shape ``(len(self.rules), len(fact_mapping))``. Entry
			``[label, i]`` holds the count recorded under `label` for the predicate whose first
			occurrence in `fact_mapping` is at index `i`; all other entries are zero.
		"""
		fact_mapping = actions_order_dataset(n=2, unique=False)
		weights = torch.zeros(len(self.rules), len(fact_mapping))

		for label in range(weights.shape[0]):
			for predicate, count in counts[label].items():
				weights[label, fact_mapping.index(predicate)] = count
		return weights

	def post_step(self, model, batch_num):
		"""Record the reasoning weights, and every `proj_interval` batches clamp them into [ETA, 1 - ETA]."""
		net = model.reasoning_network
		self.add_proj_weight(net.rule_weights)

		projection_due = self.proj_interval > 0 and batch_num % self.proj_interval == 0
		if not projection_due:
			return
		net.rule_weights.data = net.rule_weights.data.clamp(0 + ETA, 1 - ETA)

	def evaluate_rule_accuracy(self, model, k=5):
		"""HITS@k of the rules extracted from `model` against `self.rules`.

		During rule search the candidates come from `reasoning_network_var.build_var_len_rules`.
		Otherwise they come from `reasoning_network.build_top_k_rules_dynamic` (append_thresh=0.05).

		Returns:
			Tuple of (fraction of ground-truth rules matched by one of their top-k candidates,
			dict ``'rule len N'`` -> that fraction restricted to ground-truth rules of length N).
		"""
		if model.rule_search:
			candidates = model.reasoning_network_var.build_var_len_rules(k=k)
		else:
			candidates = model.reasoning_network.build_top_k_rules_dynamic(k=k, append_thresh=0.05)

		#TODO: build consistent rules either enumerate or dynamically, check lens, then check atomic events
		hits_by_len = defaultdict(list)
		num_hits = 0
		for gt_rule, top_rules in zip(self.rules, candidates):
			matches = [check_rule_match(gt_rule, pred_rule) for pred_rule in top_rules]
			hit = any(matches)
			num_hits += 1 if hit else 0
			hits_by_len[len(gt_rule)].append(hit)

		by_len = {}
		for rule_len, hits in hits_by_len.items():
			by_len[f'rule len {rule_len}'] = sum(hits) / len(hits)
		return num_hits / len(self.rules), by_len

	def check_rule_recall(self, model):
		"""Rank predicates per label by |rule weight| and report statistics on the ranking.

		The descending argsort of the absolute reasoning weights is passed to
		`reasoning_network_var.get_fact_indices` and `compute_stats`.
		"""
		# `self.get_proj_weights` in this class expects MAP relation counts (label -> Counter), not a
		# model, so rank the reasoning network's own rule weights directly. They are moved to the cpu
		# to match the cpu count matrix that `gen_combs` receives.
		attention_weights = model.reasoning_network.rule_weights.detach().cpu()
		sorted_weights = torch.flip(torch.argsort(torch.abs(attention_weights), dim=1), dims=[1])
		fact_indices = model.reasoning_network_var.get_fact_indices(sorted_weights)
		model.reasoning_network_var.compute_stats(fact_indices, sorted_weights)

	def evaluate_model_voc(self, model, val_loader):
		"""Score `model` on `val_loader` and log rule-extraction HITS@1/5/10.

		The mAP comes from the `evaluate_model_voc` that follows `TemporalRelationTraining` in
		the MRO. Outside rule search, `check_rule_recall` is also run after the HITS logging.

		Returns:
			The mAP from that `evaluate_model_voc` call.
		"""
		mAP = super(TemporalRelationTraining, self).evaluate_model_voc(model, val_loader)
		for k in (1, 5, 10):
			total_acc, accuracies = self.evaluate_rule_accuracy(model, k=k)
			logger.info(f'HITS@{k} overall: {total_acc} by len: {accuracies}')
		if not model.rule_search:
			self.check_rule_recall(model)
		return mAP

	def tune_var_model(self, model, train_loader, val_loader, cv_loader=None, save=True):
		"""Fine-tune the trainable parameters of `model`, keeping the best checkpoint by mAP on `cv_loader`.

		After every optimisation step the model is scored on `cv_loader`. Whenever that mAP ties
		or beats the best so far, a deep copy is kept, and it is also written to disk when `save`
		is set. The kept copy is evaluated on `cv_loader` after each epoch and on `val_loader`
		once training ends. Unlike the projection-based training, no `post_step` is run here.

		Args:
			model: network to train; parameters with ``requires_grad=False`` are not optimised.
			train_loader: iterable of `(sample, label)` training batches.
			val_loader: loader used for the final evaluation of the kept copy.
			cv_loader: loader used for checkpoint selection. It is passed to
				`evaluate_model_voc` after every batch, so presumably it must not be None.
			save: also persist each new best model via `self.save_model`.

		Returns:
			`(model, optimizer)`: the model after the last step (not the kept best copy)
			and its optimizer.
		"""
		model = model.to(self.device)

		# frozen parameters are left out of the optimiser entirely
		trainable = [p for p in model.parameters() if p.requires_grad]
		optimizer = torch.optim.Adam(trainable, lr=self.lr)

		logger.info(f'Training on device: {self.device}')
		if self.tqdm:
			iter_wrapper = lambda batches: tqdm(batches, total=len(train_loader))
		else:
			iter_wrapper = lambda batches: batches

		best_model = deepcopy(model)
		best_mAP = 0.0
		step = 1

		for epoch in range(1, self.epochs + 1):
			epoch_losses = []
			for sample, label in iter_wrapper(train_loader):
				# evaluation below may switch the model to eval mode, so re-enable training each step
				model.train()
				sample = sample.to(self.device)
				label = label.to(self.device)
				optimizer.zero_grad()
				loss = self.compute_loss(model(sample), label, model)
				epoch_losses.append(loss.item())
				loss.backward()
				optimizer.step()

				# checkpoint selection after every single step
				cv_mAP = super(TemporalRelationTraining, self).evaluate_model_voc(model, cv_loader, print_score=False)
				if cv_mAP >= best_mAP:
					# TODO: keep the best copy on the cpu instead of taking up GPU memory (or write it to disk)
					best_model = deepcopy(model)
					best_mAP = cv_mAP
					if save:
						self.save_model(model)
						logger.info('saving model to disk')
				step += 1
			logger.info(f'Epoch {epoch} mean batch loss = {np.mean(epoch_losses)}')
			self.evaluate_model_voc(best_model, cv_loader)

		logger.info('evaluating test data')
		self.evaluate_model_voc(best_model, val_loader)
		return model, optimizer