#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CDialog GPT2 Reader."""

import numpy as np

from knover.data.dialog_reader import DialogReader
from knover.data.diamante_reader import DiamanteReader
from knover.utils import mask, pad_batch_data


class CDialGPT2Reader(DiamanteReader):
    """Dialogue reader that marks every turn with a speaker token.

    Each context turn is prefixed with a speaker token and the same id is
    reused as the type id of every position in that turn. Speakers alternate
    so that the most recent context turn is always ``SPEAKER1_ID``; the
    response is introduced by ``SPEAKER2_ID``.
    """

    # Hard-coded vocabulary ids of the two speaker tokens.
    # NOTE(review): presumably the [speaker1] / [speaker2] entries of the
    # CDial-GPT vocabulary -- confirm against the vocab file actually in use.
    SPEAKER1_ID = 13086
    SPEAKER2_ID = 13087

    @classmethod
    def add_cmdline_args(cls, parser):
        """Register the command line arguments of ``DiamanteReader``.

        No reader-specific option is added here.

        Returns:
            Whatever ``DiamanteReader.add_cmdline_args`` returns.
        """
        return DiamanteReader.add_cmdline_args(parser)

    def __init__(self, args):
        """Initialize the reader; all configuration is handled by the parent."""
        super().__init__(args)

    def _parse_src(self, src):
        """Parse source sequence and return corresponding fields.

        Args:
            src: Dialogue context whose turns are separated by the literal
                ``[SEP]``.

        Returns:
            A dict with ``token_ids``, ``type_ids`` and ``pos_ids`` lists of
            equal length. The sequence is ``bos_id`` followed, for every kept
            turn, by its speaker token and the turn's token ids.
        """
        # tokenize src (no EOS is appended to the individual turns)
        turn_ids_list = []
        for turn in src.split("[SEP]"):
            turn = turn.strip()
            if self.data_format == "tokenized":
                tokens = turn.split(" ")
            else:
                tokens = self.tokenizer.tokenize(turn)
            turn_ids_list.append(self.tokenizer.convert_tokens_to_ids(tokens))

        # trim src: walk from the newest turn to the oldest and keep as many
        # turns as fit. One slot is reserved for the leading BOS and one slot
        # per turn for its speaker token, so the assembled sequence never
        # exceeds max_src_len.
        idx = len(turn_ids_list) - 1
        total_token_num = 1
        while idx >= 0:
            total_token_num += len(turn_ids_list[idx]) + 1
            if total_token_num > self.max_src_len:
                if self.truncate_first_turn and idx == 0:
                    # The slice bound is negative: cut exactly the overflow
                    # from the tail of the first turn.
                    truncated_ids = turn_ids_list[idx][:self.max_src_len - total_token_num]
                    if truncated_ids:
                        turn_ids_list[idx] = truncated_ids
                        idx -= 1
                break
            idx -= 1

        kept_turns = turn_ids_list[idx + 1:]

        # Pick the starting speaker so that, after alternating, the last kept
        # turn is always tagged with SPEAKER1_ID.
        speaker_tokens = (self.SPEAKER1_ID, self.SPEAKER2_ID)
        speaker_id = 1 if len(kept_turns) % 2 == 0 else 0

        # The BOS position uses bos_id both as its token and as its type id.
        src_token_ids = [self.bos_id]
        src_type_ids = [self.bos_id]
        for utt_ids in kept_turns:
            speaker_token = speaker_tokens[speaker_id]
            src_token_ids.append(speaker_token)
            src_token_ids.extend(utt_ids)
            src_type_ids.extend([speaker_token] * (len(utt_ids) + 1))
            speaker_id = 1 - speaker_id

        field_values = {
            "token_ids": src_token_ids,
            "type_ids": src_type_ids,
            "pos_ids": list(range(len(src_token_ids)))
        }

        # Internal sanity check: all fields must be aligned position-wise.
        for k in field_values:
            assert len(field_values[k]) == len(field_values["token_ids"]), \
                f"len(field_values[{k}]) != len(field_values['token_ids'])"
        return field_values

    def _parse_tgt(self, tgt):
        """Parse target sequence and return corresponding fields.

        Args:
            tgt: Response text.

        Returns:
            A dict with ``token_ids``, ``type_ids`` and ``pos_ids`` lists of
            equal length. The tokens are ``SPEAKER2_ID``, the (possibly
            truncated) response ids, then ``eos_id``; every type id is
            ``SPEAKER2_ID``.
        """
        # process tgt
        tgt = tgt.strip()
        if self.data_format == "tokenized":
            tgt_tokens = tgt.split(" ")
        else:
            tgt_tokens = self.tokenizer.tokenize(tgt)

        tgt_token_ids = self.tokenizer.convert_tokens_to_ids(tgt_tokens)

        # trim tgt: two slots are reserved for the speaker token and EOS.
        # Clamp at 0 so a tiny max_tgt_len cannot become a negative slice
        # bound, which would drop tokens from the tail instead.
        content_budget = max(self.max_tgt_len - 2, 0)
        tgt_token_ids = [self.SPEAKER2_ID] + tgt_token_ids[:content_budget] + [self.eos_id]

        field_values = {
            "token_ids": tgt_token_ids,
            "type_ids": [self.SPEAKER2_ID] * len(tgt_token_ids),
            "pos_ids": list(range(len(tgt_token_ids)))
        }

        return field_values

    def _pad_batch_records(self, batch_records, is_infer, phase=None):
        """Pad a batch of records and assemble the model inputs.

        Args:
            batch_records: Records exposing ``token_ids``, ``type_ids``,
                ``pos_ids``, ``label`` and ``tgt_start_idx`` (plus ``data_id``
                when ``is_infer`` is true).
            is_infer: Whether the batch is built for inference.
            phase: Unused by this method.

        Returns:
            A dict of batch fields. ``token_ids``, ``type_ids``, ``pos_ids``
            and ``generation_mask`` are always present. Generation inference
            adds the decoding start state (``init_score``, ``tgt_ids``,
            ``tgt_pos``, ``parent_idx``, ``tgt_generation_mask``); otherwise
            ``tgt_label``, ``tgt_idx`` and ``label_idx`` come from ``mask``.
            Inference batches also carry ``data_id``, training batches
            ``label``.
        """
        batch_size = len(batch_records)
        batch = {}
        batch_token_ids = [record.token_ids for record in batch_records]
        batch_type_ids = [record.type_ids for record in batch_records]
        batch_pos_ids = [record.pos_ids for record in batch_records]

        batch["token_ids"] = pad_batch_data(batch_token_ids, pad_id=self.pad_id)
        batch["type_ids"] = pad_batch_data(batch_type_ids, pad_id=self.pad_id)
        batch["pos_ids"] = pad_batch_data(batch_pos_ids, pad_id=self.pad_id)

        batch_label = [record.label for record in batch_records]
        batch_tgt_start_idx = [record.tgt_start_idx for record in batch_records]

        batch["generation_mask"] = self._gen_self_attn_mask(batch_token_ids, is_unidirectional=True)

        if is_infer and self.do_generation:
            # Every hypothesis starts decoding from the response speaker token.
            tgt_ids = np.full((batch_size, 1, 1), self.SPEAKER2_ID, dtype="int64")
            if self.position_style == "continuous":
                tgt_pos = np.array(batch_tgt_start_idx, dtype="int64")
            else:
                tgt_pos = np.zeros_like(batch_tgt_start_idx, dtype="int64")
            tgt_pos = tgt_pos.reshape(-1, 1, 1)
            batch["init_score"] = np.zeros_like(tgt_ids, dtype="float32").reshape(-1, 1).tolist()
            batch["tgt_ids"] = tgt_ids.tolist()
            batch["tgt_pos"] = tgt_pos.tolist()
            batch["parent_idx"] = np.arange(batch_size, dtype="int32")

            generation_mask = self._gen_self_attn_mask(
                batch_token_ids,
                batch_tgt_start_idx=batch_tgt_start_idx)
            # Only the first row of each mask is kept for the decoding step.
            batch["tgt_generation_mask"] = generation_mask[:, 0:1, :].astype("float32")
        else:
            # Training, or inference for ppl / ranking score. At inference
            # time every label passed to `mask` is set to 1.
            labels = [1] * len(batch_label) if is_infer else batch_label
            batch["tgt_label"], batch["tgt_idx"], batch["label_idx"] = mask(
                batch_tokens=batch_token_ids,
                vocab_size=self.vocab_size,
                bos_id=self.bos_id,
                tgt_starts=batch_tgt_start_idx,
                labels=labels,
                is_unidirectional=True)

        if is_infer:
            batch_data_id = [record.data_id for record in batch_records]
            batch["data_id"] = np.array(batch_data_id).astype("int64").reshape([-1, 1])
        else:
            batch["label"] = batch_label

        return batch

