import torch
import numpy as np


def calculate_mutual_implication_loss(
	output: torch.Tensor,
	main_classes: list,
	attributes: list,
	combinations: list,
	mu: float = 10.0,
) -> torch.Tensor:
	"""Return the summed product-t-norm penalty for class/attribute implication rules.

	Four groups of penalty terms are computed per sample and then summed over
	every term and every sample:

	* class -> attributes: for each main class ``c`` with at least one related
	  attribute, ``c * prod(1 - a_k)`` over its related attributes ``a_k``.
	* attribute -> classes: for each attribute ``a`` with at least one related
	  class, ``a * prod(1 - c_k)`` over its related classes ``c_k``.
	* "at least one main class": ``mu * prod(1 - c)`` over all main classes.
	* "at least one attribute": ``mu * prod(1 - a)`` over all attributes.

	Args:
		output: 2-D tensor indexed as ``output[:, column]``.
			NOTE(review): presumably (batch, n_outputs) activations in [0, 1]
			(e.g. sigmoid outputs) -- confirm against the caller.
		main_classes: column indices of ``output`` that hold the main classes.
		attributes: column indices of ``output`` that hold the attributes.
		combinations: 2-D relation matrix (anything ``np.asarray`` accepts) of
			shape ``(len(main_classes), len(attributes))``; a truthy entry
			``[r, k]`` relates ``main_classes[r]`` to ``attributes[k]``.
		mu: weight applied to the two "at least one ..." terms.

	Returns:
		A 0-dim tensor holding the sum of all penalty terms over all samples.
	"""
	# Rows follow the order of `main_classes`, columns the order of `attributes`.
	relation = np.asarray(combinations).astype(bool)
	# One mask tensor for the whole matrix, placed next to `output`, instead of
	# building a fresh tensor on every loop iteration.
	relation_mask = torch.as_tensor(relation, device=output.device)

	class_outputs = output[:, main_classes]
	attribute_outputs = output[:, attributes]
	penalty_terms = []

	# Class --> Attributes
	# The relation row is selected by the class's *position* in `main_classes`,
	# not by its column id in `output`; the two only coincide when the main
	# classes occupy columns 0..n-1.
	for row, class_column in enumerate(main_classes):
		if not relation[row].any():
			continue  # class implies no attribute: no rule to enforce
		implied_attributes = attribute_outputs[:, relation_mask[row]]
		penalty_terms.append(
			output[:, class_column] * torch.prod(1 - implied_attributes, dim=1)
		)

	# Attribute --> Classes
	for col, attribute_column in enumerate(attributes):
		if not relation[:, col].any():
			continue  # attribute is implied by no class: no rule to enforce
		implied_classes = class_outputs[:, relation_mask[:, col]]
		penalty_terms.append(
			output[:, attribute_column] * torch.prod(1 - implied_classes, dim=1)
		)

	# OR on the main classes
	penalty_terms.append(mu * torch.prod(1 - class_outputs, dim=1))

	# OR on the attributes
	penalty_terms.append(mu * torch.prod(1 - attribute_outputs, dim=1))

	# Every term is one value per sample; stack them as columns and reduce
	# everything to a single scalar.
	return torch.sum(torch.stack(penalty_terms, dim=1))
