import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
# import tensorflow as tf
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
from core.utils import set_seed, RunningMean
from model_base import LearnerBase


class Learner(LearnerBase):
    """
    This class builds a computation graph that represents the
    neural ILP model and handles related graph running activities,
    including update, predict, and get_attentions for given queries.

    For each of the ``rank`` bidirectional LSTMs, every reasoning step
    produces a softmax attention over ``num_operator + 1`` choices: the
    ``num_operator // 2`` relation matrices, their transposes, and one
    extra slot that carries the current memory forward unchanged.

    Args:
        option: hyper-parameters
    """

    def __init__(self, option):
        """Store hyper-parameters from ``option`` and build all parameters."""
        super().__init__()
        self.seed = option.seed
        self.num_step = option.num_step
        self.rank = option.rank
        self.num_layer = option.num_layer
        self.rnn_state_size = option.rnn_state_size

        self.norm = not option.no_norm
        self.thr = option.thr
        self.dropout = option.dropout
        self.learning_rate = option.learning_rate
        self.accuracy = option.accuracy
        self.top_k = option.top_k

        self.num_entity = option.num_entity
        self.num_operator = option.num_operator
        self.query_is_language = option.query_is_language

        if not option.query_is_language:
            self.num_query = option.num_query
            self.query_embed_size = option.query_embed_size
        else:
            self.vocab_embed_size = option.vocab_embed_size
            self.query_embed_size = self.vocab_embed_size
            self.num_vocab = option.num_vocab
            self.num_word = option.num_word

        set_seed(self.seed)
        self._build_parameters()
        # Exponential running means (decay 0.97) of loss / hit-rate per split.
        self.running_mean = {
            split: {
                cat: RunningMean(0.97)
                for cat in ("loss", "in_top")
            } for split in ("train", "test")
        }

    def _random_uniform_unit(self, r, c):
        """Return a float32 tensor of shape (r, c) whose rows have unit L2 norm.

        Entries are drawn from ``np.random.uniform(-6/sqrt(c), 6/sqrt(c))``
        (so the result depends on the global NumPy seed) and then every
        row is rescaled to unit norm.
        """
        bound = 6. / np.sqrt(c)
        init_matrix = np.random.uniform(-bound, bound, (r, c))
        # Normalise all rows in one vectorised step.
        init_matrix = init_matrix / np.linalg.norm(
            init_matrix, axis=1, keepdims=True)
        return torch.tensor(init_matrix, dtype=torch.float32)

    def _clip_if_not_None(self, g, v, low, high):
        """Clamp gradient ``g`` to [low, high] unless it is None.

        A gradient is None when its tensor is not connected to the
        objective; in that case ``g`` is passed through untouched.

        Returns:
            tuple: ``(clamped g or None, v)``.
        """
        if g is None:
            return (g, v)
        return (torch.clamp(g, low, high), v)

    def _build_parameters(self):
        """Create the query embedding, per-rank BiLSTMs and attention head."""
        if not self.query_is_language:
            # One extra row for the <END> token.
            query_embedding_params = self._random_uniform_unit(
                self.num_query + 1,
                self.query_embed_size)
            self.query_embedding = nn.Embedding(
                self.num_query + 1,
                self.query_embed_size,
                _weight=query_embedding_params
            )
        else:
            # One extra row for the <END> token.
            vocab_embedding_params = self._random_uniform_unit(
                self.num_vocab + 1,
                self.vocab_embed_size)
            # Row count must match the weight (num_vocab + 1); ``num_query``
            # is not even defined when queries are language.
            self.vocab_embedding = nn.Embedding(
                self.num_vocab + 1,
                self.vocab_embed_size,
                _weight=vocab_embedding_params
            )

        # One independent bidirectional LSTM per rank.
        self.cells = nn.ModuleList([
            torch.nn.LSTM(
                self.query_embed_size, self.rnn_state_size,
                self.num_layer, bidirectional=True
            )
            for _ in range(self.rank)
        ])

        # Maps the concatenated forward/backward LSTM state
        # (2 * rnn_state_size) to logits over num_operator + 1 choices.
        self.W_0 = Parameter(
            torch.tensor(
                np.random.randn(self.rnn_state_size * 2,
                                self.num_operator + 1),
                dtype=torch.float32),
        )
        self.b_0 = Parameter(
            torch.zeros((1, self.num_operator + 1), dtype=torch.float32),
        )

    def _propagate_memory(self, memory, attention, database):
        """Advance the memory of one rank by a single reasoning step.

        Args:
            memory: (batch_size, num_entity) tensor of current entity scores.
            attention: sequence of ``num_operator + 1`` tensors of shape
                (batch_size, 1) for this step; index ``r`` weights
                ``database[r]``, index ``r + num_operator // 2`` weights its
                transpose, and the last entry weights the unchanged memory.
            database: indexable by relation id ``r`` in
                ``range(num_operator // 2)``; each ``database[r]`` must
                support ``.transpose(0, 1)`` and ``.mm`` against a dense
                (num_entity, batch_size) matrix (possibly sparse — confirm
                with the caller).

        Returns:
            (batch_size, num_entity) tensor: the attention-weighted sum of all
            operator outputs, optionally row-normalised and dropped out.
        """
        half = self.num_operator // 2
        memory_t = memory.transpose(0, 1)  # (num_entity, batch_size)

        database_results = []
        for r in range(half):
            for op_matrix, op_attn in zip(
                    [database[r], database[r].transpose(0, 1)],
                    [attention[r], attention[r + half]]):
                product = op_matrix.mm(memory_t)
                database_results.append(product.transpose(0, 1) * op_attn)
        # Last attention slot: keep the current memory as-is.
        database_results.append(memory * attention[-1])

        added = sum(database_results)
        if self.norm:
            # Out-of-place so autograd never sees an in-place edit.
            added = added / added.sum(dim=1, keepdim=True).clamp_min(self.thr)
        if self.dropout > 0.:
            # Honour train()/eval(): a freshly built nn.Dropout would always be
            # in training mode. Falls back to "training" if the base class
            # does not provide the nn.Module ``training`` flag.
            added = F.dropout(added, p=self.dropout,
                              training=getattr(self, "training", True))
        return added

    def _inner_run_graph(self, queries, heads, tails, database):
        """Run the model on one batch.

        Args:
            queries: integer tensor of query ids, (batch, num_step), or of
                word ids, (batch, num_step, num_word), when
                ``query_is_language`` (shapes taken from the original TF
                placeholders — confirm with the caller).
            heads: (batch,) integer entity ids used as targets of the loss.
            tails: (batch,) integer entity ids that seed the initial memory.
            database: per-relation operator matrices; see
                ``_propagate_memory``.

        Returns:
            dict with keys:
                "final_loss": (batch,) ``-log(max(score[head], thr))``.
                "in_top": (batch,) bool hit indicator.
                "predictions": (batch, num_entity) entity scores summed over
                    ranks.
                "attention_operators": per rank, a list of ``num_step - 1``
                    tuples of ``num_operator + 1`` tensors of shape (batch, 1).
        """
        targets = F.one_hot(heads, num_classes=self.num_entity)

        if not self.query_is_language:
            rnn_inputs = self.query_embedding(queries)
        else:
            # Average the word embeddings within each query position.
            rnn_inputs = torch.mean(self.vocab_embedding(queries), dim=2)

        # nn.LSTM is sequence-first; only the first num_step - 1 positions
        # are fed to the LSTMs.
        rnn_inputs = rnn_inputs.transpose(0, 1)[:self.num_step - 1]

        attention_operators_list = []
        for cell in self.cells:
            rnn_outputs, _ = cell(rnn_inputs)
            # Per step: softmax over num_operator + 1 choices, split into
            # (batch, 1) columns so each one scales a single operator output.
            attention_operators_list.append([
                torch.split(
                    torch.softmax(
                        torch.matmul(rnn_output, self.W_0) + self.b_0, -1
                    ), 1, dim=1)
                for rnn_output in rnn_outputs
            ])

        # Every rank starts from the one-hot encoding of the tail entity and
        # applies num_step - 1 propagation steps; only the final memory is
        # used, so earlier memories need not be kept.
        start_memory = F.one_hot(tails, num_classes=self.num_entity).float()
        predictions = 0.0
        for attention_steps in attention_operators_list:
            memory = start_memory
            for t in range(self.num_step - 1):
                memory = self._propagate_memory(
                    memory, attention_steps[t], database)
            predictions += memory

        final_loss = - torch.sum(
            targets * predictions.clamp_min(self.thr).log(), 1)

        if not self.accuracy:
            topk = predictions.topk(self.top_k)[1]
            in_top = (heads.unsqueeze(1) == topk).any(dim=1)
        else:
            # topk returns (values, indices): unpack the pair directly.
            # Indexing [1] first would unpack along the batch dimension.
            _, indices = predictions.topk(self.top_k, sorted=False)
            in_top = torch.squeeze(indices) == heads

        results = {
            "final_loss": final_loss,
            "in_top": in_top,
            "predictions": predictions,
            "attention_operators": attention_operators_list,
        }
        return results

    def update(self, qq, hh, tt, mdb):
        """Run a training batch; update the train running means.

        Returns:
            tuple: per-example ``final_loss`` (still attached to the graph)
            and ``in_top``.
        """
        results = self._run_graph(qq, hh, tt, mdb)
        # Detach so the running mean never keeps this step's graph alive.
        self.running_mean["train"]["loss"].update(
            results["final_loss"].mean().detach())
        self.running_mean["train"]["in_top"].update(
            results["in_top"].float().mean())
        return results["final_loss"], results["in_top"]

    def predict(self, qq, hh, tt, mdb):
        """Evaluate a batch without gradients; update the test running means.

        Returns:
            tuple: per-example ``final_loss`` and ``in_top``.
        """
        with torch.no_grad():
            results = self._run_graph(qq, hh, tt, mdb)
        self.running_mean["test"]["loss"].update(results["final_loss"].mean())
        self.running_mean["test"]["in_top"].update(
            results["in_top"].float().mean())
        return results["final_loss"], results["in_top"]

    def get_predictions_given_queries(self, qq, hh, tt, mdb):
        """Return ``(in_top, predictions)`` for a batch, without gradients."""
        with torch.no_grad():
            results = self._run_graph(qq, hh, tt, mdb)
            return results["in_top"], results["predictions"]

    def get_attentions_given_queries(self, queries):
        """Return the operator attentions for ``queries``, without gradients.

        Heads, tails and the database are dummy placeholders: attentions
        depend only on the queries. The ``mdb`` tuples are presumably
        ``(indices, values, dense_shape)`` consumed by ``_run_graph`` —
        confirm against the base class.
        """
        with torch.no_grad():
            qq = queries
            hh = [0] * len(queries)
            tt = [0] * len(queries)
            mdb = {r: ([(0, 0)], [0.], (self.num_entity, self.num_entity))
                   for r in range(self.num_operator // 2)}
            results = self._run_graph(qq, hh, tt, mdb)
            return results["attention_operators"]

    def get_vocab_embedding(self):
        """Return the vocabulary embedding weight (language queries only)."""
        return self.vocab_embedding.weight
