import torch.nn as nn

from core.logic import Exist, Variable
from core.logic_net import \
    LogicFCLayer, LogicInputLayer, LogicFinalLayer, LogicNetTemplate
from core.utils import set_seed


class LogicFunctionTemplate(LogicNetTemplate):
    """Logic-network base class on which ``classify`` is unavailable."""

    def classify(self, *arg, **kwarg):
        """Reject the call unconditionally; all arguments are ignored.

        Raises:
            AttributeError: On every call.
        """
        # Same exception type as before; the message only makes the failure
        # self-explanatory in tracebacks.
        raise AttributeError(
            f"{type(self).__name__!r} object does not support 'classify'")


class BinaryFunction(LogicFunctionTemplate):
    """Layered logic network over a shared tuple of variables.

    The network is an input layer, followed by ``n_layers`` fully connected
    layers, followed by a final layer.  The final layer is wrapped in an
    ``Exist`` over every variable after the first two and stored as
    ``inner_formula``.
    """

    def __init__(self, n_variables, n_layers, n_width, attributes,
                 relations, init_var, temperature, seed,
                 no_grad=False):
        """Build the layer stack and the quantified output formula.

        Args:
            n_variables: Number of ``Variable`` objects to create; they are
                named ``X0`` .. ``X{n_variables - 1}``.
            n_layers: Number of ``LogicFCLayer`` instances to stack.
            n_width: Forwarded to every ``LogicFCLayer``.
            attributes: Forwarded to ``LogicInputLayer``.
            relations: Forwarded to ``LogicInputLayer``.
            init_var: Forwarded to every ``LogicFCLayer`` and to
                ``LogicFinalLayer``.
            temperature: Stored on the instance and forwarded to every
                ``LogicFCLayer`` and to ``LogicFinalLayer``.
            seed: Passed to ``set_seed`` before any layer is constructed.
            no_grad: Forwarded to the parent constructor.
        """
        super().__init__(no_grad)
        # Seeding happens before the layers are built (presumably so their
        # initialisation is reproducible -- confirm against ``set_seed``).
        set_seed(seed)

        self.variables = tuple(
            Variable(f"X{index}") for index in range(n_variables))
        self.n_layers = n_layers
        self.n_width = n_width
        self.temperature = temperature

        # Each new layer takes the most recently built layer as its input.
        stack = [LogicInputLayer(self.variables, attributes, relations)]
        for _ in range(n_layers):
            stack.append(LogicFCLayer(
                self.variables, stack[-1], n_width, init_var, temperature))
        output_layer = LogicFinalLayer(
            self.variables, stack[-1], init_var, temperature)
        stack.append(output_layer)

        self.formula_list = nn.ModuleList(stack)
        # Quantify over every variable except the first two.
        self.inner_formula = Exist(self.variables[2:], output_layer)

    def entropy(self):
        """Return the sum of ``entropy()`` over all layers in ``formula_list``."""
        return sum(stage.entropy() for stage in self.formula_list)
