#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diamante model."""

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle

from knover.models import register_model
from knover.core.model import Model
from knover.models.unified_transformer import UnifiedTransformer
from knover.modules.transformer_block import encoder, pre_process_layer
from knover.utils import repeat_array_or_tensor
from knover.utils import str2bool


@register_model("Diamante")
class Diamante(UnifiedTransformer):
    """Diamante model.

    A UnifiedTransformer extended with a preference-estimation head: a single
    linear layer (``similarity_fc``) mapping a gathered encoder feature to a
    scalar score. Training optimizes the inherited token-level LM loss jointly
    with a pairwise preference loss computed over triples of consecutive
    samples in each batch.
    """

    @classmethod
    def add_cmdline_args(cls, parser):
        """Add cmdline arguments.

        Diamante defines no arguments of its own and reuses UnifiedTransformer's.
        """
        group = UnifiedTransformer.add_cmdline_args(parser)
        return group

    def __init__(self, args, place):
        # Stored before the parent constructor runs because ``get_metrics``
        # builds constant gather indices from it.
        # NOTE(review): presumably the parent constructor builds the programs
        # (and therefore calls ``get_metrics``) — confirm.
        self.batch_size = args.batch_size
        super(Diamante, self).__init__(args, place)

    def _get_feed_dict(self, is_infer=False):
        """Get model's input feed dict.

        Args:
            is_infer: If true, get inference input feed dict, otherwise get training / evaluation input feed dict.

        Returns:
            feed_dict: A feed dict mapping keys to feed input variable.
        """
        if is_infer and self.do_generation:
            # Generation uses the parent's feed layout unchanged.
            feed_dict = super(Diamante, self)._get_feed_dict(is_infer)
        else:
            feed_dict = {}
            feed_dict["token_ids"] = layers.data(name="token_ids", shape=[-1, self.max_seq_len, 1], dtype="int64")
            feed_dict["type_ids"] = layers.data(name="type_ids", shape=[-1, self.max_seq_len, 1], dtype="int64")
            feed_dict["pos_ids"] = layers.data(name="pos_ids", shape=[-1, self.max_seq_len, 1], dtype="int64")

            if self.use_role:
                feed_dict["role_ids"] = layers.data(name="role_ids", shape=[-1, self.max_seq_len, 1], dtype="int64")
            if self.use_turn:
                feed_dict["turn_ids"] = layers.data(name="turn_ids", shape=[-1, self.max_seq_len, 1], dtype="int64")

            feed_dict["generation_mask"] = \
                layers.data(name="generation_mask",
                            shape=[-1, self.max_seq_len, self.max_seq_len],
                            dtype=self.dtype)

            # Per-sample (batch, position) index into enc_out used by the preference head.
            feed_dict["label_idx"] = layers.data(name="label_idx", shape=[-1, 2], dtype="int64")

            if is_infer:
                if not self.do_generation:
                    feed_dict["tgt_label"] = layers.data(name="tgt_label", shape=[-1, 1], dtype="int64")
                    feed_dict["tgt_idx"] = layers.data(name="tgt_idx", shape=[-1, 2], dtype="int64")
                feed_dict["data_id"] = layers.data(name="data_id", shape=[-1, 1], dtype="int64")
            else:
                feed_dict["label"] = layers.data(name="label", shape=[-1, 1], dtype="int64")
                feed_dict["tgt_label"] = layers.data(name="tgt_label", shape=[-1, 1], dtype="int64")
                feed_dict["tgt_idx"] = layers.data(name="tgt_idx", shape=[-1, 2], dtype="int64")

        return feed_dict

    def _get_similarity_score(self, feat):
        """Get similarity score.

        Applies the shared ``similarity_fc`` linear layer (output size 1, no
        activation) to ``feat``. The fixed parameter names make training and
        both inference paths reuse the same weights.
        """
        similarity_score = layers.fc(
            input=feat,
            size=1,
            act=None,
            param_attr=fluid.ParamAttr(
                name="similarity_fc.w_0",
                initializer=self.param_initializer),
            bias_attr="similarity_fc.b_0")

        return similarity_score

    @staticmethod
    def _to_index_tensor(index):
        """Flatten an integer array-like and embed it in the program as an int32 constant tensor."""
        return fluid.layers.assign(np.asarray(index).reshape(-1).astype("int32"))

    def get_metrics(self, inputs, outputs):
        """Get metrics.

        The batch is interpreted as consecutive triples of samples ``(s0, s1, s2)``.
        Each triple yields three preference pairs, in this order: ``s0 > s1``,
        ``s0 > s2`` and ``s1 > s2``. The left sample of each pair is the
        preferred one. Judging by the metric names below, ``s0`` / ``s1`` /
        ``s2`` are the human / bot / random responses.

        Args:
            inputs: Feed dict; must contain ``label_idx`` plus whatever the
                parent class needs for its metrics.
            outputs: Model outputs; must contain ``enc_out``.

        Returns:
            metrics: The parent's metrics plus ``mean_pe_loss``, ``loss`` (the
            joint loss ``mean_pe_loss + token_lm_loss``), ``score_gap``,
            ``acc``, ``human_bot_gap``, ``human_random_gap`` and
            ``bot_random_gap``.

        Raises:
            ValueError: If ``self.batch_size`` is not a multiple of 3; the triple
                indices would otherwise point past the end of the batch.
        """
        if self.batch_size % 3 != 0:
            raise ValueError(
                "Diamante requires batch_size to be a multiple of 3 (samples are grouped "
                "into triples), but got batch_size = {}.".format(self.batch_size))

        # NLL loss
        metrics = super(Diamante, self).get_metrics(inputs, outputs)

        # Preference Estimation loss: one scalar score per sample.
        feat = layers.gather_nd(outputs["enc_out"], inputs["label_idx"])
        similarity_score = self._get_similarity_score(feat)

        # Preferred / dispreferred sample index of every pair, triple by triple:
        # pairs (s0, s1), (s0, s2), (s1, s2).
        triple_start = np.arange(0, self.batch_size, 3)
        positive_index = self._to_index_tensor(
            np.stack([triple_start, triple_start, triple_start + 1], axis=1))
        negative_index = self._to_index_tensor(
            np.stack([triple_start + 1, triple_start + 2, triple_start + 2], axis=1))

        pos_sim = layers.gather(similarity_score, index=positive_index, overwrite=False)
        neg_sim = layers.gather(similarity_score, index=negative_index, overwrite=False)

        labels = layers.fill_constant_batch_size_like(pos_sim, [-1, 1], dtype=self.dtype, value=1)

        # Pairwise preference probability P(preferred > dispreferred); the loss pushes it towards 1.
        ob = paddle.nn.functional.sigmoid(pos_sim - neg_sim)
        pe_loss = paddle.nn.functional.log_loss(ob, labels)
        metrics["mean_pe_loss"] = layers.mean(pe_loss)

        # joint loss
        metrics["loss"] = metrics["mean_pe_loss"] + metrics["token_lm_loss"]
        metrics["score_gap"] = layers.mean(ob)
        metrics["acc"] = layers.mean(layers.cast(pos_sim > neg_sim, self.dtype))

        # Per pair-type averages of ob; pair k of every triple sits at positions k, k + 3, k + 6, ...
        for name, offset in (("human_bot_gap", 0), ("human_random_gap", 1), ("bot_random_gap", 2)):
            pair_index = self._to_index_tensor(np.arange(offset, self.batch_size, 3))
            metrics[name] = layers.mean(layers.gather(ob, index=pair_index, overwrite=False))

        return metrics

    def infer(self, inputs, outputs):
        """Run model inference.

        With ``do_generation``, decode with ``self.generator``, then run one
        extra step through ``_generation_network`` that feeds an EOS token and
        extends the generator's ``generation_mask`` with one extra all-ones
        column. The preference head scores that step's output.

        Without generation, the preference head scores the ``enc_out`` position
        selected by ``label_idx``.

        Returns:
            predictions: The generator's / parent's predictions plus
            ``ranking_score``, the sigmoid of the preference-head output.
        """
        if self.do_generation:
            predictions = self.generator.inference(self, inputs, outputs)

            model_input = {}

            # NOTE(review): assumes predictions["generation_mask"] is the
            # (batch, 1, seq) attention mask for the next decoding step — confirm.
            generation_mask = predictions["generation_mask"]

            # Match the mask's dtype so the concat below is valid for any model dtype.
            append_mask = layers.fill_constant_batch_size_like(generation_mask, [-1, 1], generation_mask.dtype, 1)
            append_mask = layers.unsqueeze(append_mask, [2])

            generation_mask = layers.concat([generation_mask, append_mask], axis=2)
            model_input["generation_mask"] = generation_mask

            model_input["pos_ids"] = predictions["pos_ids"]

            model_input["token_ids"] = \
                layers.fill_constant_batch_size_like(
                    model_input["generation_mask"], [-1, 1, 1], "int64", self.generator.eos_id)

            model_input["type_ids"] = \
                layers.fill_constant_batch_size_like(model_input["generation_mask"], [-1, 1, 1], "int64", 1)

            if self.use_role:
                model_input["role_ids"] = \
                    layers.fill_constant_batch_size_like(
                        model_input["generation_mask"], [-1, 1, 1], "int64", 0)
            # NOTE(review): turn_ids is not supplied even when use_turn is set —
            # confirm _generation_network tolerates that.

            feat, _ = self._generation_network(**model_input)

            # Output of the single appended EOS position.
            feat = feat[:, 0, :]
            ranking_score = self._get_similarity_score(feat)
            predictions["ranking_score"] = paddle.nn.functional.sigmoid(ranking_score)

            return predictions
        else:
            predictions = super(Diamante, self).infer(inputs, outputs)
            # ranking score
            feat = layers.gather_nd(outputs["enc_out"], inputs["label_idx"])
            ranking_score = self._get_similarity_score(feat)
            predictions["ranking_score"] = paddle.nn.functional.sigmoid(ranking_score)

            return predictions

    def _run_generation(self, inputs):
        """Run generation and attach ranking scores to every decoded candidate.

        Args:
            inputs: Numpy feed dict for ``self.infer_program``. A ``parent_idx``
                entry is added in place.

        Returns:
            predictions: One dict per decoded candidate with ``data_id``,
            ``decode_score``, ``context_token_ids``, ``response_token_ids`` and
            ``ranking_score``. Row ``i`` of the fetched ``ranking_score`` is
            attached to every candidate of example ``i``.
        """
        batch_size = self._get_batch_size(inputs)
        inputs["parent_idx"] = np.arange(batch_size, dtype="int64")
        outputs = self._execute(
            self.infer_program,
            inputs,
            self.infer_fetch_dict,
            return_numpy=False)

        data_id_list = np.array(outputs["data_id"]).reshape(-1).tolist()
        token_ids_list = np.array(outputs["token_ids"]).squeeze(2).tolist()
        ranking_score_list = np.array(outputs["ranking_score"]).tolist()

        # finished_ids carries a two-level LoD: level 0 maps each example to a
        # range of candidates, and level 1 maps each candidate to its token range.
        seq_ids = outputs["finished_ids"]
        lod = seq_ids.lod()
        seq_ids_np = np.array(seq_ids)
        seq_scores_np = np.array(outputs["finished_scores"])

        predictions = []
        for i, (data_id, token_ids, score) in enumerate(zip(data_id_list, token_ids_list, ranking_score_list)):
            for j in range(lod[0][i], lod[0][i + 1]):
                sub_start, sub_end = lod[1][j], lod[1][j + 1]
                predictions.append({
                    "data_id": data_id,
                    # The score stored at a candidate's last token is used as its decode score.
                    "decode_score": float(seq_scores_np[sub_end - 1]),
                    "context_token_ids": token_ids,
                    "response_token_ids": seq_ids_np[sub_start:sub_end].tolist(),
                    "ranking_score": score[0],
                })

        return predictions
