import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# import tensorflow as tf
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
# import time
from core.utils import set_seed, prob_to_logit, RunningMean
from core.sparse import sparse_scalarmul, squeeze, sparse_select

from core.logic import Variable, BinaryChainAnd, Or, Copy, And, Exist
from core.logic_net import \
    LogicInputLayer, LogicFinalSelectionLayer, LogicMetaLayer
from core.logic_fn import LogicFunctionTemplate

from model_base import LearnerBase


class DRUMFuntion(LogicFunctionTemplate):
    """Logic function used to reproduce the DRUM rule structure.

    The final formula is an ``Or`` over ``rank`` chains. Chain ``r`` joins
    the ``length`` selection layers ``candidates[r*length:(r+1)*length]``
    left to right with ``BinaryChainAnd``, passing the two variables
    ``X0``/``X1`` in reversed order.

    Args:
        length: number of selection layers per chain.
        rank: number of chains combined with ``Or``.
        relations: relation ids handed to the ``LogicInputLayer``.
        init_var: forwarded to every ``LogicFinalSelectionLayer``.
        temperature: forwarded to every ``LogicFinalSelectionLayer``.
        no_grad: forwarded to ``LogicFunctionTemplate``.
    """

    def __init__(self, length, rank, relations,
                 init_var, temperature, no_grad=False):
        super().__init__(no_grad)
        self.variables = tuple(Variable(f"X{idx}") for idx in range(2))
        self.length = length
        self.rank = rank
        self.temperature = temperature
        self.input_layer = LogicInputLayer(self.variables, [], relations)

        # All rank * length selectors share the same input layer and are
        # created before any chain is assembled.
        self.candidates = []
        for _ in range(rank * length):
            self.candidates.append(LogicFinalSelectionLayer(
                self.input_layer, init_var, temperature))

        chain_vars = self.variables[::-1]
        chains = []
        for r in range(rank):
            members = self.candidates[r * length:(r + 1) * length]
            chain = members[0]
            for member in members[1:]:
                chain = BinaryChainAnd(chain, member, *chain_vars)
            chains.append(chain)

        # Left fold: Or(Or(c0, c1), c2) ...
        formula = chains[0]
        for chain in chains[1:]:
            formula = Or(formula, chain)
        self.inner_formula = formula
        self.formula_list = nn.ModuleList(self.candidates)

    def entropy(self):
        """Return the sum of ``entropy()`` over every selection layer."""
        return sum(layer.entropy() for layer in self.formula_list)


class UnifiedChainTemplate(LogicFunctionTemplate):
    """Shared scaffolding for the ``UnifiedChain*`` logic functions.

    Stores the layer hyper-parameters and creates the two variables
    ``X0``/``X1`` and a ``LogicInputLayer`` over ``relations``. It then
    calls ``self.build_model()``, which this class does not define;
    subclasses provide it and are expected to set ``self.inner_formula``
    and ``self.formula_list``.

    Args:
        rank, width, depth: structural hyper-parameters read by
            ``build_model``.
        relations: relation ids handed to the ``LogicInputLayer``.
        init_var, temperature: forwarded to the layers built by
            subclasses.
        no_grad: forwarded to ``LogicFunctionTemplate``.
    """

    def __init__(self, rank, width, depth, relations,
                 init_var, temperature, no_grad=False):
        super().__init__(no_grad)
        # Hyper-parameters consumed later by the subclass's build_model().
        self.rank = rank
        self.width = width
        self.depth = depth
        self.init_var = init_var
        self.temperature = temperature
        # Two logic variables: X0 and X1.
        self.variables = tuple(Variable(f"X{idx}") for idx in range(2))
        self.input_layer = LogicInputLayer(self.variables, [], relations)
        self.build_model()

    def str_info(self, replace_vars, all_variables):
        """Delegate string rendering to the assembled inner formula."""
        formula = self.inner_formula
        return formula.str_info(replace_vars, all_variables)


class UnifiedChain1(UnifiedChainTemplate):
    """Chain-then-disjunction architecture.

    ``depth`` meta-layers offering ``Copy``/``BinaryChainAnd`` are stacked,
    followed by two meta-layers offering ``Copy``/``Or``. A single formula
    is then selected from the last layer.
    """

    def build_model(self):
        def _meta(parent, ops):
            # One LogicMetaLayer on top of `parent` with the given op menu.
            return LogicMetaLayer(
                parent, self.width, 2, ops, self.init_var, self.temperature
            )

        layers = [self.input_layer]
        parent = self.input_layer
        for _ in range(self.depth):
            parent = _meta(parent, [(Copy, {}), (BinaryChainAnd, {})])
            layers.append(parent)
        for _ in range(2):
            parent = _meta(parent, [(Copy, {}), (Or, {})])
            layers.append(parent)

        self.inner_formula = LogicFinalSelectionLayer(
            parent, self.init_var, self.temperature
        )
        self.formula_list = nn.ModuleList(layers)


class UnifiedChain2(UnifiedChainTemplate):
    """Stack of ``depth`` general meta-layers with one final selection.

    Every meta-layer offers ``Copy``, ``BinaryChainAnd``, ``And``, ``Or``
    and ``Exist`` over either variable. A single formula is selected from
    the last layer.
    """

    def build_model(self):
        first_var, second_var = self.variables[0], self.variables[1]
        current = self.input_layer
        layers = [current]
        for _ in range(self.depth):
            # A fresh op list per layer, as each layer receives its own.
            ops = [(Copy, {}), (BinaryChainAnd, {}), (And, {}), (Or, {}),
                   (Exist, {"reduce_var": [first_var]}),
                   (Exist, {"reduce_var": [second_var]})]
            current = LogicMetaLayer(
                current, self.width, 2, ops,
                self.init_var, self.temperature
            )
            layers.append(current)
        self.inner_formula = LogicFinalSelectionLayer(
            current, self.init_var, self.temperature
        )
        self.formula_list = nn.ModuleList(layers)


class UnifiedChain3(UnifiedChainTemplate):
    """General meta-layer stack with ``rank`` OR-ed final selections.

    ``depth`` meta-layers offer ``Copy``, ``BinaryChainAnd``, ``And``,
    ``Or`` and ``Exist`` over either variable. The last of them is marked
    ``caching = True`` and feeds ``rank`` independent
    ``LogicFinalSelectionLayer``s, whose outputs are combined with ``Or``.

    Raises:
        ValueError: if ``self.rank < 1``. The selections are the only
            operands of the ``Or``, so at least one is required.
    """

    def build_model(self):
        if self.rank < 1:
            raise ValueError(
                f"UnifiedChain3 requires rank >= 1, got {self.rank}")
        layers = [self.input_layer]
        for _ in range(self.depth):
            layers.append(
                LogicMetaLayer(
                    layers[-1], self.width, 2,
                    [(Copy, {}), (BinaryChainAnd, {}), (And, {}), (Or, {}),
                     (Exist, {"reduce_var": [self.variables[0]]}),
                     (Exist, {"reduce_var": [self.variables[1]]})],
                    self.init_var, self.temperature
                )
            )
        end_layer = layers[-1]
        end_layer.caching = True
        selections = [
            LogicFinalSelectionLayer(
                end_layer, self.init_var, self.temperature)
            for _ in range(self.rank)
        ]
        layers.extend(selections)
        # Fold from the newest selection backwards:
        # Or(Or(s[-1], s[-2]), s[-3]) ... . Only selection layers are
        # OR-ed. The previous `layers[-2]` indexing picked up end_layer
        # itself when rank == 1.
        inner_formula = selections[-1]
        for selection in reversed(selections[:-1]):
            inner_formula = Or(inner_formula, selection)
        self.inner_formula = inner_formula
        self.formula_list = nn.ModuleList(layers)


class UnifiedChain4(UnifiedChainTemplate):
    """Quantifier-free meta-layer stack with ``rank`` OR-ed selections.

    ``depth`` meta-layers offer ``Copy``, ``BinaryChainAnd``, ``And`` and
    ``Or``. The last of them is marked ``caching = True`` and feeds
    ``rank`` independent ``LogicFinalSelectionLayer``s, whose outputs are
    combined with ``Or``.

    Raises:
        ValueError: if ``self.rank < 1``. The selections are the only
            operands of the ``Or``, so at least one is required.
    """

    def build_model(self):
        if self.rank < 1:
            raise ValueError(
                f"UnifiedChain4 requires rank >= 1, got {self.rank}")
        layers = [self.input_layer]
        for _ in range(self.depth):
            layers.append(
                LogicMetaLayer(
                    layers[-1], self.width, 2,
                    [(Copy, {}), (BinaryChainAnd, {}), (And, {}), (Or, {})],
                    self.init_var, self.temperature
                )
            )
        end_layer = layers[-1]
        end_layer.caching = True
        selections = [
            LogicFinalSelectionLayer(
                end_layer, self.init_var, self.temperature)
            for _ in range(self.rank)
        ]
        layers.extend(selections)
        # Fold from the newest selection backwards:
        # Or(Or(s[-1], s[-2]), s[-3]) ... . Only selection layers are
        # OR-ed. The previous `layers[-2]` indexing picked up end_layer
        # itself when rank == 1.
        inner_formula = selections[-1]
        for selection in reversed(selections[:-1]):
            inner_formula = Or(inner_formula, selection)
        self.inner_formula = inner_formula
        self.formula_list = nn.ModuleList(layers)


class Learner(LearnerBase):
    """
    This class builds a computation graph that represents the
    neural ILP model and handles related graph running activities,
    including update, predict, and get_attentions for given queries.

    Args:
        option: hyper-parameters
    """

    def __init__(self, option):
        """Copy hyper-parameters from ``option`` and build ``logic_fns``.

        ``option.model`` selects the logic functions:
          * ``"drum_logic_fn_chaining<N>"`` with N in 1..4 builds
            ``num_relation`` instances of ``UnifiedChain<N>``;
          * ``"drum_logic_fn"`` builds ``DRUMFuntion`` instances:
            ``num_relation`` of them when ``no_rev_in_model`` is set,
            otherwise ``num_operator``.

        Raises:
            NotImplementedError: for any other ``option.model`` value.
        """
        super().__init__()
        self.seed = option.seed
        self.num_step = option.num_step
        self.rank = option.rank
        self.num_layer = option.num_layer
        self.rnn_state_size = option.rnn_state_size

        self.norm = not option.no_norm
        self.thr = option.thr
        self.dropout = option.dropout
        self.learning_rate = option.learning_rate
        self.accuracy = option.accuracy
        self.top_k = option.top_k

        self.num_entity = option.num_entity
        self.num_operator = option.num_operator
        self.query_is_language = option.query_is_language

        if not option.query_is_language:
            self.num_query = option.num_query
            self.query_embed_size = option.query_embed_size
        else:
            self.vocab_embed_size = option.vocab_embed_size
            self.query_embed_size = self.vocab_embed_size
            self.num_vocab = option.num_vocab
            self.num_word = option.num_word

        set_seed(self.seed)
        self.headwise = option.headwise
        self.no_rev_in_model = option.no_rev_in_model
        self.num_relation = option.num_relation
        self.width = option.width
        self.depth = option.depth
        self.init_var = option.init_var
        self.model = option.model
        if self.model.startswith("drum_logic_fn_chaining"):
            chain_classes = {
                "1": UnifiedChain1,
                "2": UnifiedChain2,
                "3": UnifiedChain3,
                "4": UnifiedChain4,
            }
            variant = self.model.split("chaining")[-1]
            if variant not in chain_classes:
                raise NotImplementedError(
                    f"unknown chaining variant in model {self.model!r}")
            model_class = chain_classes[variant]
            self.logic_fns = nn.ModuleList([
                model_class(
                    self.rank, self.width, self.depth,
                    list(range(self.num_relation+1)), self.init_var, 1
                )
                for _ in range(self.num_relation)
            ])
        elif self.model != "drum_logic_fn":
            raise NotImplementedError(f"unknown model {self.model!r}")
        else:
            # One DRUMFuntion per query relation, each reading
            # n_fns + 1 relation inputs (the extra one presumably being
            # the identity matrix appended in _inner_run_graph — confirm).
            n_fns = (self.num_relation if self.no_rev_in_model
                     else self.num_operator)
            n_inputs = n_fns + 1
            self.logic_fns = nn.ModuleList([DRUMFuntion(
                self.num_step-1, self.rank, list(range(n_inputs)),
                6./((self.num_step-1)*n_inputs*self.rank), 1
            ) for _ in range(n_fns)])
        self.soft_logic = option.soft_logic
        self.sparse = option.sparse

        # Re-seed after model construction, presumably so later randomness
        # does not depend on how many draws the model init consumed.
        set_seed(self.seed)
        self.running_mean = {
            split: {
                cat: RunningMean(0.97)
                for cat in ("loss", "in_top")
            } for split in ("train", "test")
        }

    def _random_uniform_unit(self, r, c):
        """ Initialize random and unit row norm matrix of size (r, c). """
        bound = 6. / np.sqrt(c)
        init_matrix = np.random.uniform(-bound, bound, (r, c))
        # Normalise every row to unit L2 norm in one vectorised step;
        # keeps shape (r, c) even when r == 0.
        init_matrix = init_matrix / np.linalg.norm(
            init_matrix, axis=1, keepdims=True)
        return torch.tensor(init_matrix, dtype=torch.float32)

    def _clip_if_not_None(self, g, v, low, high):
        """Clip gradient ``g`` to ``[low, high]`` unless it is None.

        A gradient is None when its tensor is not connected to the
        objective; such pairs are passed through unchanged.

        Returns:
            ``(clamped g or None, v)``
        """
        if g is None:
            return (g, v)
        return (torch.clamp(g, low, high), v)

    def _inner_run_graph(self, queries, heads, tails, database):
        """Score one batch and compute its loss and top-k hits.

        Args:
            queries: integer tensor; ``queries[:, 0]`` holds each
                example's query relation id.
            heads: integer tensor of gold entity ids, one per example.
            tails: integer tensor of the entity each example is scored
                from, one per example.
            database: indexable by relation id ``0..num_relation-1``. Each
                entry supports ``.to_dense()`` (presumably a sparse
                (num_entity, num_entity) adjacency matrix — confirm).

        Returns:
            dict with
              ``final_loss``: per example, ``-log(max(pred[head], thr))``.
                  In the sparse branch only stored prediction entries
                  enter the product, so a head with no stored score
                  contributes 0. NOTE(review): this differs from the dense
                  branch; confirm it is intended.
              ``in_top``: bool per example; whether the head is among the
                  ``top_k`` predictions (or equals the arg-top prediction
                  when ``self.accuracy`` is set).
              ``predictions``: dense per-example entity scores.
        """
        targets = F.one_hot(heads, num_classes=self.num_entity)

        def _dense(tensor):
            # Densify unless the model runs in sparse mode.
            return tensor if self.sparse else tensor.to_dense()

        def _sparse(tensor):
            # Sparsify only when the model runs in sparse mode.
            return tensor.to_sparse() if self.sparse else tensor

        def _diag(tensor):
            # Square diagonal matrix from a 1-D tensor, keeping its layout.
            if not tensor.is_sparse:
                return tensor.diag()
            assert tensor.ndim == 1
            tensor = tensor.coalesce()
            # A diagonal entry's row and column indices are identical.
            # Repeating the index tensor keeps it on the values' device,
            # which a host-side .tolist() round-trip does not.
            return torch.sparse_coo_tensor(
                tensor.indices().repeat(2, 1),
                tensor.values(),
                (tensor.shape[0],) * 2
            )

        # One matrix per relation id 0..num_relation-1, then an identity
        # matrix (diagonal of ones) as the final entry.
        data_matrix = torch.stack(
            [
                _dense(database[_i])
                for _i in range(self.num_relation)
            ] + [
                _diag(_sparse(torch.ones_like(database[0][0].to_dense())))
            ]
        )

        # Positional batch tuple passed to `.scoring`. Only slot 2 (the
        # relation tensors; converted with prob_to_logit in dense mode) and
        # slot 5 (presumably the relation ids) are filled here.
        # NOTE(review): data_matrix holds num_relation + 1 matrices while
        # slot 5 lists num_operator + 1 ids — confirm these agree.
        batch = (None, None,
                 data_matrix[None] if self.sparse else
                 prob_to_logit(data_matrix[None]),
                 None,
                 None, list(range(self.num_operator+1)),)
        if not self.headwise:
            # Score each example with the logic function of its query and
            # read the row of its tail entity.
            predicted_matrix = []
            for _tail, _query in zip(tails, queries[:, 0]):
                score_tensor, _, _ = self.logic_fns[_query].scoring(
                    batch, self.soft_logic)
                predicted_matrix.append(
                    score_tensor[0, _tail]
                )
            predictions = torch.stack(predicted_matrix)
        else:
            # One logic function, chosen by the first example's query, is
            # evaluated once. Examples sharing that query read the tail's
            # column, all others read the tail's row (presumably the
            # inverse relation — confirm).
            _query = queries[0, 0]
            score_tensor, _, _ = self.logic_fns[
                _query % len(self.logic_fns)].scoring(batch, self.soft_logic)
            if self.sparse:
                predictions = []
                score_tensor = squeeze(score_tensor, 0)
                for _ind, _q in enumerate(queries[:, 0]):
                    if _q == _query:
                        _pred = sparse_select(score_tensor, 1, tails[_ind])
                    else:
                        _pred = sparse_select(score_tensor, 0, tails[_ind])
                    # Normalise each row; the denominator is floored at thr.
                    _pred = sparse_scalarmul(
                        _pred, 1/max(self.thr, torch.sparse.sum(_pred)))
                    predictions.append(_pred)
                predictions = torch.stack(predictions)
            else:
                pos_ind = queries[:, 0] == _query
                neg_ind = queries[:, 0] != _query
                # (batch, *score_tensor.shape[2:]) buffer. Sizing it from
                # the shape avoids indexing entity rows with relation ids.
                predictions = score_tensor.new_zeros(
                    (queries.shape[0],) + tuple(score_tensor.shape[2:]))
                predictions[pos_ind] += \
                    score_tensor[0, :, tails[pos_ind]].transpose(0, 1)
                predictions[neg_ind] += score_tensor[0, tails[neg_ind]]
        if self.norm and not self.sparse:
            # Row-normalise. The denominator is floored at thr, matching
            # the sparse branch, so an all-zero row cannot produce NaN.
            predictions = predictions / \
                predictions.sum(1, keepdim=True).clamp_min(self.thr)

        if not self.sparse:
            final_loss = - torch.sum(
                targets * predictions.clamp_min(self.thr).log(), 1)
        else:
            predictions = predictions.coalesce()
            final_loss = - torch.sparse.sum(
                targets.to_sparse() * torch.sparse_coo_tensor(
                    predictions.indices(),
                    predictions.values().clamp_min(self.thr).log(),
                    predictions.shape
                ), 1
            ).to_dense()
            predictions = predictions.to_dense()

        if not self.accuracy:
            topk = predictions.topk(self.top_k)[1]
            in_top = (heads.unsqueeze(1) == topk).any(dim=1)
        else:
            # topk returns (values, indices); unpack the pair itself rather
            # than the indices tensor. Comparing against `heads` assumes
            # top_k == 1, so the squeeze yields one id per example.
            _, indices = predictions.topk(self.top_k, sorted=False)
            in_top = torch.squeeze(indices) == heads

        results = {
            "final_loss": final_loss,
            "in_top": in_top,
            "predictions": predictions,
        }
        return results

    def update(self, qq, hh, tt, mdb):
        """Run the graph with autograd enabled; return (loss, in_top)."""
        results = self._run_graph(qq, hh, tt, mdb)
        return results["final_loss"], results["in_top"]

    def predict(self, qq, hh, tt, mdb):
        """Run the graph without gradients; return (loss, in_top)."""
        with torch.no_grad():
            results = self._run_graph(qq, hh, tt, mdb)
            return results["final_loss"], results["in_top"]

    def get_predictions_given_queries(self, qq, hh, tt, mdb):
        """Run the graph without gradients; return (in_top, predictions)."""
        with torch.no_grad():
            results = self._run_graph(qq, hh, tt, mdb)
            return results["in_top"], results["predictions"]

    def get_attentions_given_queries(self, queries):
        """Return ``results["attention_operators"]`` for ``queries``.

        Runs the graph on dummy heads/tails and a one-entry database.
        NOTE(review): ``_inner_run_graph`` in this file does not produce an
        ``"attention_operators"`` entry. Unless ``_run_graph`` (in
        ``LearnerBase``) adds it, this raises KeyError — confirm.
        """
        with torch.no_grad():
            qq = queries
            hh = [0] * len(queries)
            tt = [0] * len(queries)
            mdb = {r: ([(0, 0)], [0.], (self.num_entity, self.num_entity))
                   for r in range(self.num_operator // 2)}
            results = self._run_graph(qq, hh, tt, mdb)
            return results["attention_operators"]

    def get_vocab_embedding(self):
        """Return ``self.vocab_embedding.weight``.

        NOTE(review): ``vocab_embedding`` is not assigned in this class;
        presumably it is set elsewhere — confirm before relying on it.
        """
        return self.vocab_embedding.weight
