import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
from core.utils import \
    align_results, categorical_softmax_entropy,\
    set_seed, expand_and_permute
from core.sparse import \
    spmatmul, squeeze, sparse_vectormul, sparse_dropout_last
from core.logic import \
    Not, Copy, And, Xor, Or, Equals, Nand, DeduceTo, \
    Exist, ForAll, Variable, LogicalEquation, Wrapper, Attribute, Relation, \
    Binary, Unary, Quantified


class LogicNetElement(LogicalEquation):
    """Common base for the layers of a logic network.

    Stores the variables the layer works over and forwards ``no_grad`` to
    ``LogicalEquation``.  ``forward`` and ``str_info`` are placeholders that
    raise ``NotImplementedError`` in this base class.
    """

    def __init__(self, variables, no_grad=False):
        super().__init__(no_grad)
        self.variables = variables

    def forward(self, batch, soft_logic, final_forall_logic=None):
        """Placeholder; raises ``NotImplementedError``."""
        raise NotImplementedError

    def str_info(self, replace_vars, all_variables):
        """Placeholder; raises ``NotImplementedError``."""
        raise NotImplementedError


class LogicInputLayer(LogicNetElement):
    """Input layer enumerating atomic predicates over the layer's variables.

    ``input_list`` holds one ``Attribute`` per (attribute name, variable)
    pair followed by one ``Relation`` per (relation name, ordered pair of
    distinct variables).  ``n_width`` is the number of such inputs.

    Args:
        variables: variables the predicates are instantiated over.
        attributes: iterable of attribute names.
        relations: iterable of relation names.
        no_grad: forwarded to ``LogicNetElement``.
        remove_inputs: optional iterable of attribute/relation names to leave
            out.  ``None`` (the default) removes nothing.  The relation named
            ``"Equal"`` is always left out.
    """

    def __init__(self, variables, attributes, relations, no_grad=False,
                 remove_inputs=None):
        super().__init__(variables, no_grad)
        # Materialise the exclusions once: this avoids a shared mutable
        # default, accepts any iterable (not only lists) and stops the
        # exclusion list being rebuilt for every candidate relation.
        excluded = () if remove_inputs is None else tuple(remove_inputs)
        excluded_relations = excluded + ("Equal",)

        attribute_inputs = [
            Attribute(attribute_name, variable)
            for attribute_name in attributes
            for variable in self.variables
            if attribute_name not in excluded
        ]
        relation_inputs = [
            Relation(relation_name, variable1, variable2)
            for relation_name in relations
            for variable1 in self.variables
            for variable2 in self.variables
            if variable1 != variable2
            if relation_name not in excluded_relations
        ]
        self.input_list = attribute_inputs + relation_inputs
        self.n_width = len(self.input_list)

    def inner_scoring(self, batch, soft_logic):
        """Score every input predicate and align the results.

        Returns:
            ``(scores, valid_mask, variables)`` where ``scores`` and
            ``valid_mask`` come from ``align_results`` applied to the
            per-predicate ``scoring`` results.
        """
        result_list = [
            _element.scoring(batch, soft_logic)
            for _element in self.input_list
        ]
        scores, valid_mask = align_results(result_list, self.variables, 1)
        return scores, valid_mask, self.variables

    def str_info(self, replace_vars, all_variables):
        """Describe the input selected by the next pending ``repr_index``.

        ``repr_index`` is not set by this class; the layers in this file that
        print through an input layer assign ``last_layer.repr_index`` before
        calling.  One index is consumed per call.
        """
        info = self.input_list[self.repr_index.pop(0)].str_info(
            replace_vars, all_variables)
        return info

    def entropy(self):
        """Return 0: this layer contributes nothing to the entropy sum."""
        return 0


class LogicMetaLayer(LogicNetElement):
    """Layer whose units softly select their inputs and a logic operator.

    Each of the ``n_width`` units owns a softmax distribution over the
    previous layer's outputs for its first argument (and for its second one
    when ``n_args == 2``) plus a softmax distribution over ``ops``.
    ``inner_scoring`` mixes the operator results with those weights;
    ``str_info`` collapses each distribution to one choice to describe a
    single formula.
    """

    # Selection strategy used by ``str_info``: "max" takes the arg-max of a
    # distribution, "sampling" draws from it with ``np.random.choice``.
    repr_mode = "max"

    def __init__(self, last_layer, n_width, n_args, ops,
                 init_var, temperature, no_grad=False):
        """Create the selection weights.

        Args:
            last_layer: preceding layer; must expose ``variables``,
                ``n_width`` and ``scoring``.
            n_width: number of units in this layer.
            n_args: 1 or 2 -- how many argument selectors each unit has.
            ops: list of ``(op_type, extra_info)`` pairs.  Quantified
                operators are read through ``extra_info["reduce_var"]``.
            init_var: scale applied to the ``torch.randn`` initial weights.
            temperature: multiplier applied to weights before softmax.
            no_grad: forwarded to ``LogicNetElement``.
        """
        super().__init__(last_layer.variables, no_grad)
        self.init_var = init_var
        self.temperature = temperature
        self.last_layer = last_layer
        input_width = self.last_layer.n_width
        self.n_width = n_width
        self.n_args = n_args
        # ops: [(op_type, extra_info), ...]
        self.ops = ops
        self.arg1_weights = Parameter(
            torch.randn(input_width, n_width) * init_var
        )
        assert n_args in (1, 2)
        if n_args == 2:
            self.arg2_weights = Parameter(
                torch.randn(input_width, n_width) * init_var
            )
        else:
            # With a single selector only one-argument operators can be fed.
            assert all(issubclass(_op[0], Unary) for _op in ops)
        self.op_weights = Parameter(
            torch.randn(n_width, len(ops)) * init_var
        )

    def inner_scoring(self, batch, soft_logic):
        """Score the layer as a weighted mixture of operator results.

        Returns:
            ``(layer_scores, layer_mask, variables)``.  ``layer_mask`` is
            ``None`` on the sparse path.
        """
        input_scores, valid_mask, variables = self.last_layer.scoring(
            batch, soft_logic)
        bs = input_scores.shape[0]
        n_nodes = input_scores.shape[-1]
        # NOTE(review): ``n_vars`` is not set in this class; presumably it is
        # provided by ``LogicalEquation`` -- confirm.
        nv = self.n_vars
        nw = self.n_width
        assert variables == self.variables
        # Soft selections: one distribution per unit (column) over inputs,
        # one distribution per unit (row) over operators.
        op_weights = (self.op_weights * self.temperature).softmax(1)
        arg1_weights = (self.arg1_weights * self.temperature).softmax(0)
        arg1 = spmatmul(input_scores, arg1_weights, 1)
        if self.n_args == 2:
            arg2_weights = (self.arg2_weights * self.temperature).softmax(0)
            arg2 = spmatmul(input_scores, arg2_weights, 1)
        if valid_mask is not None:
            # Collapse the mask over the input-width axis, then repeat it
            # once per unit so it matches the shape of ``arg1``.
            valid_mask = valid_mask.min(1)[0].unsqueeze(1).repeat(
                1, nw, *((1,)*nv)).view(*arg1.shape)
        # Argument tuple per operator, in the same order as ``self.ops``.
        op_args = []
        for _op, _info in self.ops:
            if issubclass(_op, Quantified):
                op_args.append(
                    (arg1, valid_mask, variables, _info["reduce_var"]))
            else:
                if issubclass(_op, Unary):
                    op_args.append((arg1, valid_mask, variables))
                else:
                    op_args.append((arg1, arg2, valid_mask, variables))

        if arg1.is_sparse:
            # Start from an empty sparse tensor of the output shape.
            layer_scores = torch.sparse_coo_tensor(
                [[]] * (2+nv), [], (bs, nw, *((n_nodes,)*nv)),
                device=arg1.device, dtype=arg1.dtype
            )
            # Operator masks are not accumulated on the sparse path.
            layer_mask = None
            for _i, ((_op, _), _arg) in enumerate(zip(self.ops, op_args)):
                op_score, op_mask, variables = _op.static_scoring(
                    _arg, soft_logic)
                layer_scores = layer_scores + sparse_vectormul(
                    op_score, op_weights[:, _i], 1
                )

        else:
            layer_scores = torch.zeros(
                (bs, nw, *((n_nodes,)*nv)),
                device=arg1.device, dtype=arg1.dtype)
            layer_mask = torch.ones(
                (bs, nw, *((n_nodes,)*nv)), device=arg1.device, dtype=bool)
            for (_op, _), _arg, _wght in zip(
                    self.ops, op_args, op_weights.split(1, 1)):
                op_score, op_mask, variables = _op.static_scoring(
                    _arg, soft_logic)
                # Bring the operator output back onto this layer's variables
                # before mixing it in.
                op_score = expand_and_permute(
                    op_score, variables, self.variables)
                layer_scores = layer_scores + \
                    op_score * _wght.view(1, -1, *((1,)*nv))
                if op_mask is not None:
                    op_mask = expand_and_permute(
                        op_mask, variables, self.variables)
                    layer_mask = layer_mask.min(op_mask)

        if layer_scores.is_sparse:
            # NOTE(review): ``self.sparse_dropout_last`` is not set in this
            # class; presumably configured elsewhere -- confirm.
            layer_scores = sparse_dropout_last(
                layer_scores, self.sparse_dropout_last)
        return layer_scores, layer_mask, self.variables

    def str_info(self, replace_vars, all_variables):
        """Describe the formula of one unit after collapsing its selections.

        Consumes one index from ``self.repr_index`` (assigned by the caller),
        picks an operator and argument(s) for that unit according to
        ``repr_mode``, and multiplies the returned ``info["confidence"]`` by
        the softmax weight of every pick.

        Raises:
            ValueError: if ``repr_mode`` is neither "max" nor "sampling".
        """
        index = self.repr_index.pop(0)

        def _select_fn(weights):
            # Returns (selected weight, selected index) for a 1-D
            # probability vector.
            assert weights.ndim == 1
            if self.repr_mode == "max":
                return weights.max(0)
            elif self.repr_mode == "sampling":
                choice_num = weights.shape[0]
                choice_index = np.random.choice(
                    range(choice_num), p=weights.detach().cpu().numpy())
                return weights[choice_index], choice_index
            else:
                raise ValueError(
                    f"unknown repr_mode: {self.repr_mode!r}")

        arg1_confidence, arg1_index = _select_fn((
            self.arg1_weights[:, index] * self.temperature
        ).softmax(0))
        if self.n_args == 2:
            arg2_confidence, arg2_index = _select_fn((
                self.arg2_weights[:, index] * self.temperature
            ).softmax(0))
        op_confidence, op_index = _select_fn((
            self.op_weights[index] * self.temperature
        ).softmax(0))

        last_layer = self.last_layer
        _op_type, _info = self.ops[op_index]
        # Test for quantifiers first, exactly as ``inner_scoring`` does.  The
        # ``n_args == 1`` assertion in ``__init__`` requires every operator
        # (quantifiers included) to be a ``Unary`` subclass, so testing
        # ``Unary`` first would build a quantifier without its ``reduce_var``.
        # Operators matching none of the three families keep using the
        # quantifier-style constructor, as before.
        use_quantifier_form = (
            issubclass(_op_type, Quantified)
            or not issubclass(_op_type, (Unary, Binary))
        )
        if use_quantifier_form:
            last_layer.repr_index = [arg1_index]
            op = _op_type(_info["reduce_var"], last_layer)
        elif issubclass(_op_type, Unary):
            last_layer.repr_index = [arg1_index]
            op = _op_type(last_layer)
        else:
            # Both operands print through the same previous layer; it pops
            # one queued index per operand.
            last_layer.repr_index = [arg1_index, arg2_index]
            op = _op_type(last_layer, last_layer)
        info = op.str_info(replace_vars, all_variables)
        info["confidence"] *= op_confidence.cpu().detach().numpy()
        info["confidence"] *= arg1_confidence.cpu().detach().numpy()
        if issubclass(_op_type, Binary):
            info["confidence"] *= arg2_confidence.cpu().detach().numpy()
        return info

    def entropy(self):
        """Sum ``categorical_softmax_entropy`` over all selection weights."""
        entropy = (
            categorical_softmax_entropy(self.arg1_weights, 0) +
            categorical_softmax_entropy(self.op_weights, 1)
        )
        if self.n_args == 2:
            entropy = entropy + categorical_softmax_entropy(
                self.arg2_weights, 0)
        entropy = entropy.sum()
        return entropy


class LogicFCLayer(LogicMetaLayer):
    """Two-argument meta layer offering the full operator set.

    The operators are ``Not``, ``Copy``, ``And``, ``Xor``, ``Or``,
    ``Equals``, ``Nand`` and ``DeduceTo``, followed by one ``Exist`` and one
    ``ForAll`` per variable of ``layer``.

    Args:
        layer: preceding layer.
        n_width: number of units.
        init_var: scale of the random initial weights.
        temperature: softmax temperature multiplier.
        offset: stored as ``self.offset``.  ``LogicMetaLayer.__init__`` has
            no such parameter, so it is no longer forwarded (doing so raised
            ``TypeError``).  Now optional because ``LogicNet`` omits it.
        no_grad: forwarded to ``LogicMetaLayer``.
    """

    def __init__(self, layer, n_width,
                 init_var, temperature, offset=None, no_grad=False):
        variables = layer.variables
        ops = [
            (Not, {}), (Copy, {}), (And, {}), (Xor, {}), (Or, {}),
            (Equals, {}), (Nand, {}), (DeduceTo, {}),
        ]
        # LogicMetaLayer reads the quantified variable through
        # ``extra_info["reduce_var"]``, so the info must be a dict.
        ops += [(Exist, {"reduce_var": _var}) for _var in variables]
        ops += [(ForAll, {"reduce_var": _var}) for _var in variables]
        super().__init__(
            layer, n_width, 2, ops, init_var, temperature, no_grad)
        self.offset = offset


class LogicFinalLayer(LogicFCLayer):
    """Width-1 ``LogicFCLayer`` whose width axis is squeezed away.

    Args:
        layer: preceding layer.
        init_var: scale of the random initial weights.
        temperature: softmax temperature multiplier.
        no_grad: forwarded to ``LogicFCLayer`` as ``no_grad``.
        offset: forwarded to ``LogicFCLayer`` as ``offset``; defaults to
            ``None``.
    """

    def __init__(self, layer, init_var, temperature, no_grad=False,
                 offset=None):
        # The parent's fifth positional parameter is ``offset``; previously
        # ``no_grad`` was passed in that slot and never reached the parent's
        # ``no_grad`` parameter.
        super().__init__(layer, 1, init_var, temperature, offset, no_grad)

    def inner_scoring(self, batch, soft_logic):
        """Score the single unit and drop axis 1 from scores and mask."""
        layer_scores, valid_mask, variables = super().inner_scoring(
            batch, soft_logic
        )
        layer_scores = layer_scores.squeeze(1)
        if valid_mask is not None:
            valid_mask = valid_mask.squeeze(1)
        return layer_scores, valid_mask, variables

    def str_info(self, replace_vars, all_variables):
        """Describe the formula of the only unit (index 0)."""
        self.repr_index = [0]
        return super().str_info(replace_vars, all_variables)


class LogicSelectionLayer(LogicMetaLayer):
    """Single-argument meta layer whose only operator is ``Copy``."""

    def __init__(self, layer, n_width,
                 init_var, temperature, no_grad=False):
        copy_only = [(Copy, {})]
        super().__init__(
            last_layer=layer,
            n_width=n_width,
            n_args=1,
            ops=copy_only,
            init_var=init_var,
            temperature=temperature,
            no_grad=no_grad,
        )


class LogicFinalSelectionLayer(LogicSelectionLayer):
    """Width-1 selection layer with the width axis removed from its output."""

    def __init__(self, layer,
                 init_var, temperature, no_grad=False):
        super().__init__(layer, 1, init_var, temperature, no_grad)

    def inner_scoring(self, batch, soft_logic):
        """Score the single unit, then drop axis 1 from scores and mask."""
        wide_scores, wide_mask, _ = super().inner_scoring(batch, soft_logic)
        # ``squeeze`` is the project helper from ``core.sparse``.
        narrow_scores = squeeze(wide_scores, 1)
        narrow_mask = None if wide_mask is None else wide_mask.squeeze(1)
        return narrow_scores, narrow_mask, self.variables

    def str_info(self, replace_vars, all_variables):
        """Describe the formula of the only unit (index 0)."""
        self.repr_index = [0]
        return super().str_info(replace_vars, all_variables)


class LogicIndexLayer(LogicNetElement):
    """View onto a single unit of another layer.

    Args:
        layer: layer to read from; must expose ``variables``, ``scoring``
            and ``str_info``.
        index: position along axis 1 of the wrapped layer's output.
        no_grad: forwarded to ``LogicNetElement`` (it used to be accepted
            but silently dropped).
    """

    def __init__(self, layer, index, no_grad=False):
        super().__init__(layer.variables, no_grad)
        self.last_layer = layer
        self.index = index

    def inner_scoring(self, batch, soft_logic):
        """Score the wrapped layer and select ``self.index`` on axis 1."""
        scores, valid_mask, variables = self.last_layer.scoring(
            batch, soft_logic)
        scores = scores.select(1, self.index)
        if valid_mask is not None:
            valid_mask = valid_mask.select(1, self.index)
        return scores, valid_mask, self.variables

    def str_info(self, replace_vars, all_variables):
        """Describe the wrapped layer's unit at ``self.index``."""
        # The wrapped layer pops the unit to describe from ``repr_index``.
        self.last_layer.repr_index = [self.index]
        return self.last_layer.str_info(replace_vars, all_variables)

    def entropy(self):
        """Return 0: this layer contributes nothing to the entropy sum."""
        return 0


class LogicNetTemplate(Wrapper):
    """Shared behaviour for the network wrappers: loss, printing, entropy.

    Relies on ``self.inner_formula`` for printing and entropy, and on
    ``self.formula_list`` for entropy unless a subclass overrides it.
    """

    def classify(self, batch, y, soft_logic, final_forall_logic, loss_fn,
                 distinct_variables=False):
        """Run the network and compute a per-example loss.

        ``loss_fn`` is ``"nll"`` for ``-log(0.5 - y * (0.5 - scores))`` or
        ``"square"`` for ``(0.5 + y * (0.5 - scores)) ** 2``.  Any other
        value raises ``NotImplementedError`` after the forward pass.

        Returns:
            ``(loss, scores, score_tensor)``.
        """
        scores, score_tensor, _valid_mask = self(
            batch, soft_logic, final_forall_logic, distinct_variables)
        if loss_fn not in ("nll", "square"):
            raise NotImplementedError()
        signed_gap = y * (0.5 - scores)
        if loss_fn == "nll":
            loss = - (0.5 - signed_gap).log()
        else:
            loss = (0.5 + signed_gap).pow(2)
        return loss, scores, score_tensor

    def str_info(self, replace_vars, all_variables):
        """Delegate to ``self.inner_formula.str_info``."""
        formula = self.inner_formula
        return formula.str_info(replace_vars, all_variables)

    def __repr__(self):
        # NOTE(review): ``all_variables`` is not assigned in this class;
        # presumably supplied by ``Wrapper`` -- confirm.
        info = self.str_info({}, self.all_variables)
        prefix = f"confidence={info['confidence']},  "
        return prefix + info["str"]

    def entropy(self):
        """Add up the entropy of ``inner_formula`` and every listed layer."""
        total = self.inner_formula.entropy()
        return sum((_formula.entropy() for _formula in self.formula_list),
                   total)


class LogicNet(LogicNetTemplate):
    """Stack of an input layer, ``n_layers`` FC layers and a final layer.

    Args:
        n_variables: number of variables, named ``X0`` .. ``X{n-1}``.
        n_layers: number of ``LogicFCLayer`` instances stacked on the input.
        n_width: width of each ``LogicFCLayer``.
        attributes: attribute names for the input layer.
        relations: relation names for the input layer.
        init_var: scale of the random initial weights.
        temperature: softmax temperature multiplier.
        seed: passed to ``set_seed`` before any layer is built.
        no_grad: forwarded to the ``Wrapper`` base class.
        remove_inputs: optional iterable of input names to leave out of the
            input layer; ``None`` (the default) removes nothing.
    """

    def __init__(self, n_variables, n_layers, n_width, attributes, relations,
                 init_var, temperature, seed, no_grad=False,
                 remove_inputs=None):
        super().__init__(no_grad)
        set_seed(seed)
        self.variables = tuple(Variable(f"X{i}") for i in range(n_variables))
        self.n_layers = n_layers
        self.n_width = n_width
        self.temperature = temperature
        # building layers
        # Hand the input layer a real list so a ``None`` default never
        # reaches its membership tests.
        layer = LogicInputLayer(
            self.variables, attributes, relations,
            remove_inputs=[] if remove_inputs is None else remove_inputs)
        layer_list = [layer]
        for _ in range(n_layers):
            # NOTE(review): ``LogicFCLayer.__init__`` in this file also
            # declares an ``offset`` parameter that this call does not
            # supply -- confirm the intended value.
            layer = LogicFCLayer(
                layer, n_width,
                init_var, temperature)
            layer_list.append(layer)
        self.formula_list = nn.ModuleList(layer_list)
        # ``LogicFinalLayer`` takes (layer, init_var, temperature); it reads
        # the variables from ``layer`` itself.  Passing ``self.variables``
        # first shifted every argument by one position.
        self.inner_formula = LogicFinalLayer(layer, init_var, temperature)

    @classmethod
    def from_graph(cls, graph, n_variables, n_layers, n_width,
                   init_var, temperature, no_grad=False, seed=0):
        """Build a network from the attribute/relation names of ``graph``.

        ``seed`` is new and defaults to 0.  Previously ``no_grad`` was passed
        positionally into the constructor's ``seed`` slot and the real
        ``no_grad`` was never forwarded.
        NOTE(review): callers that relied on the old misbinding to pass a
        seed through ``no_grad`` must switch to ``seed=``.
        """
        attributes = graph.attribute_names
        relations = graph.relation_names
        return cls(n_variables, n_layers, n_width, attributes, relations,
                   init_var, temperature, seed, no_grad)

    @classmethod
    def from_graphs(cls, graphs, n_variables, n_layers, n_width,
                    init_var, temperature, no_grad=False, seed=0):
        """Build from several graphs that share one schema.

        Asserts that every graph has the same attribute and relation names
        as the first, then builds the network from the first graph.
        """
        reference = graphs[0]
        assert all(another.attribute_names ==
                   reference.attribute_names for another in graphs[1:])
        assert all(another.relation_names ==
                   reference.relation_names for another in graphs[1:])
        return cls(n_variables, n_layers, n_width,
                   reference.attribute_names, reference.relation_names,
                   init_var, temperature, seed, no_grad)


class UnaryLayer(LogicMetaLayer):
    """Width-1, single-argument layer choosing among Copy, Not and Exist.

    ``Exist`` quantifies over ``reduce_variable``, which is also kept as an
    attribute.
    """

    def __init__(self, layer, reduce_variable,
                 init_var, temperature, no_grad=False):
        unary_ops = [
            (Copy, {}),
            (Not, {}),
            (Exist, {"reduce_var": reduce_variable}),
        ]
        super().__init__(
            layer, 1, 1, unary_ops, init_var, temperature, no_grad)
        self.reduce_variable = reduce_variable


class DeduceNet(LogicNetTemplate):
    """Network asserting ``Equals(condition, target)`` for one attribute.

    ``condition`` is a ``LogicNet`` built without ``deduce_attribute`` among
    its inputs.  ``target`` is a ``UnaryLayer`` over the attribute
    ``deduce_attribute`` of ``X0``.

    Args:
        n_variables, n_layers, n_width: forwarded to ``LogicNet``.
        deduce_attribute: name of the attribute to deduce.
        attributes, relations: input names forwarded to ``LogicNet``.
        init_var: scale of the random initial weights.
        temperature: softmax temperature multiplier.
        seed: forwarded to ``LogicNet``.
        no_grad: forwarded to the ``Wrapper`` base class.
    """

    def __init__(self, n_variables, n_layers, n_width, deduce_attribute,
                 attributes, relations,
                 init_var, temperature, seed, no_grad=False):
        super().__init__(no_grad)
        self.condition = LogicNet(n_variables, n_layers, n_width, attributes,
                                  relations, init_var, temperature, seed,
                                  remove_inputs=[deduce_attribute])
        self.variables = self.condition.variables
        self.reduce_variable = Variable("X0")
        target = Attribute(deduce_attribute, Variable("X0"))
        # ``UnaryLayer`` takes (layer, reduce_variable, init_var,
        # temperature).  The old call put a variables tuple first, which
        # shifted every argument by one position.
        # NOTE(review): this assumes ``Attribute`` exposes the ``variables``
        # / ``n_width`` / ``scoring`` interface that ``LogicMetaLayer``
        # needs from its input layer -- confirm.
        self.target = UnaryLayer(
            target, self.reduce_variable, init_var, temperature)
        self.inner_formula = Equals(
            self.condition, self.target
        )

    def entropy(self):
        """Entropy of the target layer plus that of the condition network."""
        return self.target.entropy() + self.condition.entropy()

    @classmethod
    def from_graph(cls, graph, n_variables, n_layers, n_width,
                   target,
                   init_var, temperature, no_grad=False, seed=0):
        """Build a network from the attribute/relation names of ``graph``.

        ``seed`` is new and defaults to 0.  Previously ``no_grad`` was passed
        positionally into the constructor's ``seed`` slot and the real
        ``no_grad`` was never forwarded.
        NOTE(review): callers that relied on the old misbinding to pass a
        seed through ``no_grad`` must switch to ``seed=``.
        """
        attributes = graph.attribute_names
        relations = graph.relation_names
        return cls(n_variables, n_layers, n_width, target,
                   attributes, relations,
                   init_var, temperature, seed, no_grad)

    @classmethod
    def from_graphs(cls, graphs, n_variables, n_layers, n_width,
                    target,
                    init_var, temperature, no_grad=False, seed=0):
        """Build from several graphs that share one schema.

        Asserts that every graph has the same attribute and relation names
        as the first, then builds the network from the first graph.
        """
        reference = graphs[0]
        assert all(another.attribute_names ==
                   reference.attribute_names for another in graphs[1:])
        assert all(another.relation_names ==
                   reference.relation_names for another in graphs[1:])
        return cls(n_variables, n_layers, n_width, target,
                   reference.attribute_names, reference.relation_names,
                   init_var, temperature, seed, no_grad)


class DeduceRelationNet(DeduceNet):
    """Variant of ``DeduceNet`` whose target is a relation between X0 and X1.

    The target is a plain ``Relation`` rather than a learnable layer, so only
    the condition network contributes to the entropy.
    """

    def __init__(self, n_variables, n_layers, n_width, deduce_relation,
                 attributes, relations,
                 init_var, temperature, seed, no_grad=False):
        # Skip ``DeduceNet.__init__`` on purpose: start the initialiser
        # chain at the class after ``DeduceNet`` in the MRO.
        super(DeduceNet, self).__init__(no_grad)
        condition_net = LogicNet(
            n_variables, n_layers, n_width, attributes, relations,
            init_var, temperature, seed,
            remove_inputs=[deduce_relation],
        )
        self.condition = condition_net
        self.variables = condition_net.variables
        self.reduce_variable = Variable("X0")
        relation_target = Relation(
            deduce_relation, Variable("X0"), Variable("X1"))
        self.target = relation_target
        self.inner_formula = Equals(self.condition, self.target)

    def entropy(self):
        """Entropy of the condition network only."""
        return self.condition.entropy()
