#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CDial GPT2 model."""

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

from knover.models import register_model
from knover.models.unified_transformer import UnifiedTransformer
from knover.models.diamante import Diamante


@register_model("CDialGPT2")
class CDialGPT2(Diamante):
    """GPT2 which doesn't have token_type_emb and mask_lm_trans_fc.

    Type ids are embedded through the token embedding table (there is no
    separate type embedding parameter), and the logits are computed directly
    from the encoder output without an intermediate transform layer.
    """

    # Type id fed to the network for every generated position.
    # NOTE(review): presumably the vocabulary id of a speaker token in the
    # CDial-GPT vocabulary -- confirm against the vocabulary file.
    _GENERATION_TYPE_ID = 13087

    @classmethod
    def add_cmdline_args(cls, parser):
        """Add cmdline arguments.

        This model adds no arguments of its own; it returns the argument
        group produced by `Diamante.add_cmdline_args`.
        """
        group = Diamante.add_cmdline_args(parser)
        return group

    def __init__(self, args, place):
        """Initialize the model by delegating to the parent constructor."""
        super(CDialGPT2, self).__init__(args, place)

    def _gen_input(self,
                   token_ids,
                   type_ids,
                   pos_ids,
                   role_ids,
                   turn_ids,
                   input_mask,
                   aux_emb=None,
                   name=""):
        """Generate the embedding input and the attention bias.

        Args:
            token_ids: Token ids, looked up in the token embedding table.
            type_ids: Type ids, looked up in the SAME token embedding table.
            pos_ids: Position ids, looked up in the position embedding table.
            role_ids: Unused by this model.
            turn_ids: Unused by this model.
            input_mask: A single attention mask, or a tuple/list of masks.
            aux_emb: Unused by this model.
            name: Prefix prepended to the embedding parameter names.

        Returns:
            A pair `(emb_out, attn_bias)`. `attn_bias` is a single variable
            when `input_mask` is a single mask, and a list of variables (one
            per mask) when `input_mask` is a tuple/list.
        """
        token_emb_out = layers.embedding(
            input=token_ids,
            size=[self.vocab_size, self.emb_size],
            dtype=self.dtype,
            param_attr=fluid.ParamAttr(
                name=name + self.token_emb_name, initializer=self.param_initializer))

        # gpt2 has no token_type_emb: type ids reuse the token embedding
        # parameter (same name, same size) instead of a dedicated table.
        type_emb_out = layers.embedding(
            input=type_ids,
            size=[self.vocab_size, self.emb_size],
            dtype=self.dtype,
            param_attr=fluid.ParamAttr(
                name=name + self.token_emb_name, initializer=self.param_initializer))

        pos_emb_out = layers.embedding(
            input=pos_ids,
            size=[self.max_position_seq_len, self.emb_size],
            dtype=self.dtype,
            param_attr=fluid.ParamAttr(
                name=name + self.pos_emb_name, initializer=self.param_initializer))

        emb_out = token_emb_out + pos_emb_out + type_emb_out

        def mask_to_bias(mask):
            # (mask - 1) * 1e4: a mask value of 1 maps to 0, a mask value of 0
            # maps to -1e4. The bias is excluded from back-propagation.
            bias = layers.unsqueeze(
                layers.scale(x=mask, scale=1e4, bias=-1.0, bias_after_scale=False), 1)
            bias.stop_gradient = True
            return bias

        # generate n-head self-attention mask
        if isinstance(input_mask, (tuple, list)):
            # Build a real list from each individual mask. A generator here
            # would be exhausted before being returned to the caller.
            attn_bias = [mask_to_bias(mask) for mask in input_mask]
        else:
            attn_bias = mask_to_bias(input_mask)

        return emb_out, attn_bias

    def _calc_logits(self, enc_out, tgt_idx=None, name=""):
        """Compute vocabulary logits from the encoder output.

        Args:
            enc_out: Encoder output.
            tgt_idx: Optional indices with shape `[*, 2]` used to gather the
                target positions from `enc_out`. When None, `enc_out` is
                reshaped to `[-1, hidden_size]` and every position is used.
            name: Prefix of the token embedding parameter name, used when
                weight sharing is enabled.

        Returns:
            The logits over the vocabulary.

        Raises:
            ValueError: If `tgt_idx` is given but is not of shape `[*, 2]`.
        """
        if tgt_idx is None:
            seq_feat = layers.reshape(x=enc_out, shape=[-1, self.hidden_size])
        elif len(tgt_idx.shape) == 2 and tgt_idx.shape[1] == 2:
            seq_feat = layers.gather_nd(input=enc_out, index=tgt_idx)
        else:
            raise ValueError(f"Invalid indices shape {tgt_idx.shape} is used")

        # No mask_lm_trans_fc: the features go straight to the output layer.
        seq_trans_feat = seq_feat
        if self.weight_sharing:
            # Reuse the token embedding matrix as the output projection.
            logits = layers.matmul(
                x=seq_trans_feat,
                y=fluid.default_main_program().global_block().var(
                    name + self.token_emb_name),
                transpose_y=True)
            if self.cls_bias:
                logits += layers.create_parameter(
                    shape=[self.vocab_size],
                    dtype=self.dtype,
                    attr=fluid.ParamAttr(name="mask_lm_out_fc.b_0"),
                    is_bias=True)
        else:
            seq_out_bias_attr = "mask_lm_out_fc.b_0" if self.cls_bias else False
            logits = layers.fc(
                input=seq_trans_feat,
                size=self.vocab_size,
                param_attr=fluid.ParamAttr(
                    name="mask_lm_out_fc.w_0",
                    initializer=self.param_initializer),
                bias_attr=seq_out_bias_attr)
        return logits

    def infer(self, inputs, outputs):
        """Run model inference.

        Same flow in both branches: obtain the predictions, then attach a
        sigmoid-activated `ranking_score`. In the generation branch the
        `type_ids` of the extra scoring step are fixed to
        `_GENERATION_TYPE_ID`.

        Args:
            inputs: The model inputs.
            outputs: The forward outputs of the network.

        Returns:
            The predictions dict, with an added `ranking_score` entry.
        """
        if self.do_generation:
            predictions = self.generator.inference(self, inputs, outputs)

            # Build the input of one extra step appended after generation.
            model_input = {}

            generation_mask = predictions["generation_mask"]

            # Extend the mask by one column of ones along axis 2 for the
            # appended position.
            append_mask = layers.fill_constant_batch_size_like(generation_mask, [-1, 1], "float32", 1)
            append_mask = layers.unsqueeze(append_mask, [2])

            generation_mask = layers.concat([generation_mask, append_mask], axis=2)
            model_input["generation_mask"] = generation_mask

            model_input["pos_ids"] = predictions["pos_ids"]

            # The appended token is the generator's end-of-sequence id.
            model_input["token_ids"] = \
                layers.fill_constant_batch_size_like(
                    model_input["generation_mask"], [-1, 1, 1], "int64", self.generator.eos_id)

            model_input["type_ids"] = \
                layers.fill_constant_batch_size_like(
                    model_input["generation_mask"], [-1, 1, 1], "int64", self._GENERATION_TYPE_ID)

            feat, _ = self._generation_network(**model_input)

            # Keep the feature of the single appended position.
            feat = feat[:, 0, :]
            ranking_score = self._get_similarity_score(feat)
            ranking_score = paddle.nn.functional.sigmoid(ranking_score)
            predictions["ranking_score"] = ranking_score

            return predictions

        else:
            # NOTE(review): super(Diamante, self) skips Diamante.infer and
            # dispatches to the next class in the MRO; the ranking score is
            # then computed here. Presumably intentional -- confirm.
            predictions = super(Diamante, self).infer(inputs, outputs)
            # ranking score
            feat = layers.gather_nd(outputs["enc_out"], inputs["label_idx"])
            ranking_score = self._get_similarity_score(feat)
            predictions["ranking_score"] = paddle.nn.functional.sigmoid(ranking_score)

            return predictions

    def _prepare_timestep_input(self, state, step_idx):
        """Prepare the input of one generation step with fixed `type_ids`.

        Delegates to the parent implementation, then overwrites
        `model_input["type_ids"]` with a tensor filled with
        `_GENERATION_TYPE_ID`, using the `generation_mask` entry as the
        batch-size reference.

        Args:
            state: The generation state, forwarded to the parent.
            step_idx: The current step index, forwarded to the parent.

        Returns:
            The triple `(model_input, pre_ids, pre_scores)` from the parent,
            with `model_input["type_ids"]` replaced.
        """
        model_input, pre_ids, pre_scores = super(CDialGPT2, self)._prepare_timestep_input(state, step_idx)

        model_input["type_ids"] = layers.fill_constant_batch_size_like(
            model_input["generation_mask"], [-1, 1, 1], "int64", self._GENERATION_TYPE_ID)
        return model_input, pre_ids, pre_scores
