# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Fine-tune GPT / Retro models on JSON QA datasets (SFT)."""

import torch
from functools import partial, reduce
import sys, os

sys.path.append(os.path.abspath(os.path.join(
    os.path.join(os.path.dirname(__file__), "../../../"))))
from megatron.training import get_args, get_retro_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.utils import get_blend_from_list
from megatron.training import pretrain
from megatron.training.utils import get_ltor_masks_and_position_ids
from megatron.training.utils import average_losses_across_data_parallel_group
from pretrain_gpt import is_dataset_built_on_rank
from model_provider import model_provider
from gpt_builders import gpt_builder
from tools.retro.sft.dataset_conv import JsonQADataset, JsonQADatasetConfig, RetroJsonQADataset, RetroJsonQADatasetConfig


def get_tasks_args(parser):
    """Register the task / fine-tuning arguments on ``parser``.

    All options are added to an argument group titled ``'tasks'``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument_group``) to extend.

    Returns:
        The same ``parser`` instance, with the new arguments registered.
    """

    def _parse_bool(value):
        # argparse's ``type=bool`` calls ``bool(str)``, which is True for any
        # non-empty string (including "False"); parse common spellings instead.
        if isinstance(value, bool):
            return value
        normalized = str(value).strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        # argparse reports a ValueError from a ``type`` callable as a usage error.
        raise ValueError(f"invalid boolean value: {value!r}")

    group = parser.add_argument_group(title='tasks')

    # parameters for the knowledgeable dialogue generation
    group.add_argument('--task', type=str, default=None,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                            'evaluation only.')
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                            'the data loader')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
    group.add_argument('--data-folder', type=str, default=None,
                       help='dataset folder')
    group.add_argument('--answer-loss-only', action='store_true', default=False,
                       help='take the loss from answer part, ignore the context')
    group.add_argument('--weight', type=float, default=1)
    group.add_argument('--adaptor', action='store_true', default=False)
    group.add_argument('--project-size', type=int, default=256)
    group.add_argument('--cyclic-train-iters', type=int, default=None)
    # NOTE(review): ``type=dict`` cannot convert a command-line string
    # (dict('x') raises), so this option is effectively default-only.
    group.add_argument('--stored_params', type=dict, default=dict())
    group.add_argument('--eval_ppl', action='store_true', default=False)
    group.add_argument('--debug', action='store_true', default=False)
    group.add_argument('--add_retriever', action='store_true', default=False)
    group.add_argument('--return_doc_ids', action='store_true', default=False)
    group.add_argument('--return_neighbor_ids', action='store_true', default=False)
    group.add_argument('--add_offset_doc_ids', action='store_true', default=False)
    group.add_argument('--offset_dict_path', type=str, default='')
    group.add_argument('--neighbors_path', type=str, default='')
    group.add_argument('--valid_neighbors_path', type=str, default='')
    group.add_argument('--database_path', type=str, default='')
    group.add_argument('--valid_database_path', type=str, default='')
    group.add_argument('--encoder-layers', type=int, default=12)
    group.add_argument('--encoder-hidden-dropout', type=float, default=0.1)
    group.add_argument('--encoder-attention-dropout', type=float, default=0.1)
    group.add_argument('--k', type=int, default=2)
    group.add_argument('--r', type=int, default=128)
    group.add_argument('--m', type=int, default=64)
    group.add_argument('--dpr-mode', type=str, default="multi")
    group.add_argument('--faiss-ckpt', type=str, default='')
    group.add_argument('--original-db-file', type=str, default="")
    group.add_argument('--ft_neighbours', type=int, default=1)
    group.add_argument('--reuse-top', action='store_true', default=False)
    group.add_argument('--shuffle_topn', action='store_true', default=False)
    group.add_argument('--chunk0', action='store_true', default=False)
    group.add_argument('--disable-encoder', action='store_true', default=False)
    group.add_argument('--qa-space-pad', action='store_true', default=False)
    group.add_argument('--retro-mask-encoder', action='store_true', default=False)
    group.add_argument('--without-title', action='store_true', default=False)
    group.add_argument('--longform-answer', action='store_true', default=False)
    group.add_argument('--bert-retriever-neighbours', action='store_true', default=False)
    group.add_argument('--prefix', action='store_true', default=False)
    group.add_argument('--question-in-encoder', action='store_true', default=False)
    # by default reset eval for each eval
    group.add_argument('--reset_eval', type=_parse_bool, default=True)
    return parser


def get_batch(data_iterator):
    """Fetch the next batch, broadcast it, and build model inputs.

    Ranks that own a data iterator read the next item from it; the item is
    then distributed via ``tensor_parallel.broadcast_data`` so that ranks
    passing ``data_iterator=None`` receive the same tensors.

    Args:
        data_iterator: iterator yielding dicts with ``'text'`` and
            ``'answer_mask'`` (plus ``'neighbor_tokens'`` and ``'context_len'``
            when ``args.retro_add_retriever`` is set), or None on ranks that
            only receive the broadcast.

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``; when
        ``args.retro_add_retriever`` is set, additionally
        ``(neighbor_tokens, neighbor_attention_mask, neighbor_position_ids)``,
        where ``neighbor_attention_mask`` is always None.

    Raises:
        ValueError: if reading from ``data_iterator`` fails (the original
            exception is chained as the cause).
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text', 'answer_mask']
    datatype = torch.int64

    if args.retro_add_retriever:
        keys += ['neighbor_tokens', 'context_len']

    # Broadcast data.
    if data_iterator is not None:
        try:
            data = next(data_iterator)
        except Exception as exc:
            # Chain the cause so the real failure (exhausted iterator, dataset
            # error, ...) stays visible in the traceback.
            raise ValueError("error with data_iterator") from exc
    else:
        data = None

    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    if args.retro_add_retriever:
        # 'context_len' is only broadcast in retro mode, so the chunk length
        # can only be derived (and the retro args updated) here.
        chunk_size = torch.min(data_b['context_len'])
        retro_args = get_retro_args()
        # two chunk retro has at least seq_len / 2 of chunk size
        retro_args.retro_gpt_chunk_length = max(
            args.seq_length // 2, args.seq_length - chunk_size.item())

    # Unpack: labels are the tokens shifted left by one position.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Shift the answer mask the same way so it lines up with ``labels``.
    answer_mask = data_b["answer_mask"].float()[:, 1:].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    if args.answer_loss_only:
        # Keep loss only where the answer mask is set (see --answer-loss-only).
        loss_mask = loss_mask * answer_mask

    if not args.retro_add_retriever:
        return tokens, labels, loss_mask, attention_mask, position_ids

    neighbor_tokens = data_b['neighbor_tokens'].view(
        -1, retro_args.retro_gpt_retrieved_length).long()  # [bs * l * k, r]
    _, _, neighbor_position_ids = get_ltor_masks_and_position_ids(
        neighbor_tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    neighbor_attention_mask = None
    return (tokens, labels, loss_mask, attention_mask, position_ids,
            neighbor_tokens, neighbor_attention_mask, neighbor_position_ids)


def loss_func(loss_mask, output_tensor):
    """Compute the masked mean of per-token losses.

    Args:
        loss_mask: tensor of per-token weights; flattened and cast to float.
        output_tensor: per-token losses with the same number of elements as
            ``loss_mask``; flattened and cast to float.

    Returns:
        ``(loss, {'lm loss': reduced})`` where ``loss`` is
        ``sum(losses * mask) / sum(mask)`` and ``reduced`` is that value after
        ``average_losses_across_data_parallel_group`` (used for logging).
    """
    flat_mask = loss_mask.view(-1).float()
    per_token = output_tensor.float().view(-1)

    # NOTE(review): if the mask sums to 0 (e.g. --answer-loss-only with no
    # answer tokens in the batch) this is 0/0 and yields NaN.
    masked_total = torch.sum(per_token * flat_mask)
    loss = masked_total / flat_mask.sum()

    # Reduce loss for logging.
    reduced = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': reduced[0]}


def forward_step(data_iterator, model):
    """Run one forward pass and return the output plus its loss closure.

    The batch fetch is timed under the ``'batch-generator'`` timer.

    Args:
        data_iterator: forwarded to ``get_batch``.
        model: called with ``tokens, position_ids, attention_mask`` and
            ``labels``; in retro mode it additionally receives the
            ``retriever_input_ids``, ``retriever_position_ids`` and
            ``retriever_attn_mask`` keyword arguments.

    Returns:
        ``(output_tensor, loss_fn)`` where ``loss_fn`` is ``loss_func`` with
        this batch's ``loss_mask`` bound as its first argument.
    """
    args = get_args()
    timers = get_timers()
    use_retriever = args.retro_add_retriever

    timers('batch-generator', log_level=2).start()
    batch = get_batch(data_iterator)
    timers('batch-generator').stop()

    if not use_retriever:
        tokens, labels, loss_mask, attention_mask, position_ids = batch
        output_tensor = model(tokens, position_ids, attention_mask,
                              labels=labels)
        return output_tensor, partial(loss_func, loss_mask)

    (tokens, labels, loss_mask, attention_mask, position_ids,
     neighbor_tokens, neighbor_attention_mask, neighbor_position_ids) = batch
    output_tensor = model(
        tokens,
        position_ids,
        attention_mask,
        retriever_input_ids=neighbor_tokens,
        retriever_position_ids=neighbor_position_ids,
        retriever_attn_mask=neighbor_attention_mask,
        labels=labels,
    )
    return output_tensor, partial(loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets.

    ``args.data_path`` is a flat list that is grouped into consecutive pairs
    ``[weight, name]``. Each ``name`` selects the files
    ``<args.data_folder>/<name>/<name>_QA_train.json`` and ``..._QA_dev.json``.

    - With a single group, its train/dev files become the train/valid splits.
    - With several groups, only the train files are blended (with their
      weights) into the train split; valid and test are left empty.

    Args:
        train_val_test_num_samples: requested sample counts per split,
            forwarded to ``BlendedMegatronDatasetBuilder``.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` as produced by the builder.

    Raises:
        ValueError: if ``args.data_path`` names no dataset.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def qa_json_path(name, split_name):
        # <data_folder>/<name>/<name>_QA_<split_name>.json
        return os.path.join(args.data_folder, name, f"{name}_QA_{split_name}.json")

    # Group the flat list into pairs, e.g.
    # ['0.5', 'nq', '0.5', 'tqa'] -> [['0.5', 'nq'], ['0.5', 'tqa']].
    blend = [args.data_path[i:i+2] for i in range(0, len(args.data_path), 2)]

    if not blend:
        raise ValueError("--data-path must name at least one QA dataset")

    if len(blend) == 1:
        # The single group is a list slice ([name] or [weight, name]); the
        # dataset name is its last element. Each split is wrapped in a list,
        # matching the list form passed to get_blend_from_list below.
        name = blend[0][-1]
        blend_per_split = [
            [qa_json_path(name, "train")],
            [qa_json_path(name, "dev")],
            None,
        ]
    else:
        # Interleave [weight, train_path, weight, train_path, ...]; dev files
        # are not used in the multi-dataset case.
        train_blend = []
        for weight, name in blend:
            train_blend += [weight, qa_json_path(name, "train")]
        blend_per_split = [train_blend, None, None]

    blend_per_split = [get_blend_from_list(split_blend) for split_blend in blend_per_split]

    extra_kwargs = {}

    if args.retro_add_retriever:
        retro_args = get_retro_args()
        dataset_cls = RetroJsonQADataset
        config_cls = RetroJsonQADatasetConfig
        extra_kwargs["retro_num_neighbors"] = args.retro_num_neighbors
        extra_kwargs["retro_gpt_retrieved_length"] = retro_args.retro_gpt_retrieved_length
    else:
        dataset_cls = JsonQADataset
        config_cls = JsonQADatasetConfig

    config = config_cls(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend_per_split=blend_per_split,
        split=args.split,
        path_to_cache=args.data_cache_path,
        tokenizer=tokenizer,
        ft_neighbours=args.ft_neighbours,
        bert_retriever_neighbours=args.bert_retriever_neighbours,
        longform_answer=args.longform_answer,
        inference_only=False,
        retrieved_neighbours=False,
        fix_newsqa=True,
        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
        **extra_kwargs
    )

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_cls,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        config
    ).build()
    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    # Model factory: model_provider with the GPT builder bound as its first
    # argument.
    gpt_model_provider = partial(model_provider, gpt_builder)

    # NOTE: ModelType.encoder_or_decoder was left here as a commented-out
    # alternative to the retro decoder model type.
    pretrain(
        train_valid_test_datasets_provider,
        gpt_model_provider,
        ModelType.retro_decoder,
        forward_step,
        extra_args_provider=get_tasks_args,
    )
