# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Sample Generate GPT"""
import torch
import os
import sys
from typing import Union

sys.path.append(os.path.abspath(os.path.join(
    os.path.join(os.path.dirname(__file__), "../../../"))))
from megatron.training import get_args, get_retro_args
from megatron.training import print_rank_0
from megatron.training import get_tokenizer
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.core.models.gpt import GPTModel
from megatron.training import get_model
from tools.retro.text_generation.retro_api import retro_generate_and_post_process
from tools.retro.sft.sft_retro import get_tasks_args
from tools.retro.sft.dataset_conv import reformat_prompt, preprocess, reformat_prompt_short
import numpy as np
import time
import megatron.legacy.model
from megatron.training.arguments import core_transformer_config_from_args



def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Build the legacy GPT model used for retro text generation.

    Args:
        pre_process (bool, optional): Set to True if the model needs to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to True if the model needs to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The constructed model
        (always the legacy implementation in this script).

    Raises:
        AssertionError: If ``args.use_legacy_models`` is not set.
    """
    print_rank_0('building GPT model ...')
    megatron_args = get_args()
    transformer_config = core_transformer_config_from_args(megatron_args)

    assert megatron_args.use_legacy_models, 'retro text generation only implemented for legacy models'

    # Core models are not supported here yet, so the legacy GPT is built unconditionally.
    model_kwargs = dict(
        num_tokentypes=0,
        parallel_output=False,
        pre_process=pre_process,
        post_process=post_process,
    )
    return megatron.legacy.model.GPTModel(transformer_config, **model_kwargs)


def pad_neighbours_for_query_only(args, nb_tokens, pad_id, ft_neighbours):
    """Select the neighbours for one query and pad/truncate each to a fixed length.

    Args:
        args: Argument namespace; ``reuse_top`` and ``retro_num_neighbors`` are read.
        nb_tokens: Sequence of token-id lists, one per retrieved neighbour.
        pad_id: Token id appended to neighbours shorter than the retrieved length.
        ft_neighbours: Number of leading neighbours to skip when
            ``args.reuse_top`` is false.

    Returns:
        A numpy array with one row per selected neighbour; every row has
        ``retro_gpt_retrieved_length`` entries (taken from the retro args).

    Raises:
        ValueError: If fewer than ``args.retro_num_neighbors`` neighbours are
            available after selection.
    """
    retro_args = get_retro_args()
    r = retro_args.retro_gpt_retrieved_length

    # Either reuse the top-k neighbours, or take the k neighbours that follow
    # the first `ft_neighbours` ones.
    if args.reuse_top:
        valid_nb_tokens = nb_tokens[:args.retro_num_neighbors]
    else:
        valid_nb_tokens = nb_tokens[ft_neighbours:args.retro_num_neighbors + ft_neighbours]

    # Truncate long neighbours and right-pad short ones so all rows have length r.
    neighbours_tokens = []
    for nb_token in valid_nb_tokens:
        if len(nb_token) >= r:
            nb_token = nb_token[:r]
        else:
            nb_token = nb_token + [pad_id] * (r - len(nb_token))
        neighbours_tokens.append(nb_token)
    print("len(nb_tokens)", len(nb_tokens))
    print("len(neighbours_tokens)", len(neighbours_tokens))
    print("args.retro_num_neighbors", args.retro_num_neighbors)

    if len(neighbours_tokens) < args.retro_num_neighbors:
        # The original used `assert ValueError(...)`, which asserts a truthy
        # exception object and therefore never fired; actually raise it.
        raise ValueError("neighbours are not enough, add empty ones and create mask for those empty ones")
    return np.array(neighbours_tokens)


def add_text_generate_args(parser):
    """Register the text generation command-line arguments.

    The parser is first extended by ``get_tasks_args`` and then receives a
    'text generation' argument group.

    Args:
        parser: The argument parser to extend.

    Returns:
        The parser returned by ``get_tasks_args`` with the text generation
        group added.
    """

    parser = get_tasks_args(parser)
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=256,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                            'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='Number of samples to generate unconditionally, '
                            'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                            'instead of using previously computed keys/values.')
    group.add_argument("--epsilon", type=float, default=0.01,
                       help="Minimum factor by which each probability is multiplied")
    group.add_argument("--debug-gen", action='store_true',
                       help="If set, additional debugging output is printed to stdout")
    group.add_argument('--length-penalty', type=float, default=1.0,
                       help='length penalty')
    # The two help strings below previously read 'project size for adapters'
    # (a copy-paste leftover); they now describe how the values are used when
    # slicing the input rows for generation.
    group.add_argument('--gen-start-idx', type=int, default=0,
                       help='Index of the first input row to generate from.')
    group.add_argument('--num-gen', type=int, default=-1,
                       help='Number of input rows to generate for, starting at '
                            '--gen-start-idx; values <= 0 use all remaining rows.')
    group.add_argument('--ckpt-step', type=int, default=None,
                       help='setting ckpt step manually')
    group.add_argument("--short-format", action='store_true',
                       help='Use short format QA')
    group.add_argument("--use-retrieved-neighbours", action='store_true', default=False,
                       help='Use retrieved neighbours')
    group.add_argument('--template-id', type=int, default=0,
                       help='template id for generation,')
    return parser


def generate_samples_conditional(model):
    """Generate a continuation for each selected row of ``args.sample_input_file``.

    Rank 0 loads the input rows, builds the prompts batch by batch (up to
    ``args.global_batch_size`` rows per batch), runs generation and yields
    the continuations. All other ranks call
    ``retro_generate_and_post_process(model)`` without prompts on every
    iteration and yield nothing. The loop ends on every rank once rank 0
    broadcasts the termination flag.

    Args:
        model: The model to generate with; it is switched to eval mode.

    Yields:
        str: On rank 0 only, the generated text following the prompt, cut at
        the first "<|endoftext|>" (and stripped in that case), with newlines
        replaced by spaces.
    """
    args = get_args()
    start = time.time()
    avg_time = []
    tokenizer = get_tokenizer()
    model.eval()
    is_rank_zero = torch.distributed.get_rank() == 0

    if is_rank_zero:
        data = preprocess(args.sample_input_file, inference_only=True,
                          retrieved_neighbours=args.use_retrieved_neighbours)
        print("total rows {}".format(len(data)))
        all_data = data[args.gen_start_idx:]  # start from gen_start_idx
        if args.num_gen > 0:
            all_data = all_data[:args.num_gen]
        input_count = len(all_data)
        input_pos = 0

    terminate_runs = 0
    while True:
        # Keep all ranks in lock-step at the start of every batch.
        torch.distributed.barrier()
        if is_rank_zero:
            sentences = []
            print("global batch size", args.global_batch_size)
            for _ in range(args.global_batch_size):
                print(input_pos)
                if input_pos >= input_count:
                    print("reach the last row")
                    break
                sample = all_data[input_pos]
                input_pos += 1

                max_target_len = args.out_seq_length
                query, _, neighbours = sample

                # NOTE(review): neighbours_array is overwritten for every
                # sample, so only the last sample's neighbours are passed to
                # generation below — confirm this is intended when
                # global_batch_size > 1.
                neighbours_array = pad_neighbours_for_query_only(
                    args,
                    [tokenizer.tokenize(neighbour) for neighbour in neighbours],
                    tokenizer.eod,
                    args.ft_neighbours)
                print("neighbours_array.shape", neighbours_array.shape)

                if args.short_format:
                    input_tokens = reformat_prompt_short(query, neighbours, args.task, args.ft_neighbours,
                                                         max_target_len,
                                                         tokenizer, args.seq_length)
                else:
                    input_tokens = reformat_prompt(query, neighbours, args.task, args.ft_neighbours, max_target_len,
                                                   tokenizer, args.seq_length, template_id=args.template_id)
                raw_text = tokenizer.detokenize(input_tokens)
                print(raw_text)
                sentences.append(raw_text)
            retro_args = get_retro_args()

            # NOTE(review): temperature is fixed at 1.0 here; the
            # --temperature argument is not consulted by this call.
            # Only the first of the four returned values is used.
            resp_sentences, _, _, _ = retro_generate_and_post_process(
                model,
                prompts=sentences,
                neighbours_array=neighbours_array,
                tokens_to_generate=args.seq_length - retro_args.retro_gpt_chunk_length,
                return_output_log_probs=False,
                top_k_sampling=args.top_k,
                top_p_sampling=args.top_p,
                add_BOS=False,
                temperature=1.0)
            print("len of resp_sentences", len(resp_sentences))
            for prompt, generation in zip(sentences, resp_sentences):
                # Drop the prompt prefix to keep only the continuation.
                datum = generation[len(prompt):]
                print("prompt:", generation[:len(prompt)])
                if "<|endoftext|>" in datum:
                    datum = datum[:datum.find("<|endoftext|>")].strip()
                datum = datum.replace("\n", " ")
                print("cont:", datum)
                yield datum
            avg_time.append((time.time() - start) / args.global_batch_size)
            print("avg time for each sample: ", sum(avg_time) / len(avg_time))
            start = time.time()
            if input_pos >= input_count:
                print("finish all lines")
                terminate_runs = 1
        else:
            retro_generate_and_post_process(model)

        # Rank 0 decides when to stop; broadcast its flag to every rank.
        terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
        torch.distributed.broadcast(terminate_runs_tensor, 0)
        terminate_runs = terminate_runs_tensor[0].item()

        if terminate_runs == 1:
            return


def generate_and_write_samples_conditional(model):
    """Run conditional generation and write one continuation per line.

    The output path is ``args.sample_output_file``; when that is None it
    falls back to ``args.sample_input_file`` with ".out" appended. Every rank
    opens the file and drives the generator, but only rank 0 writes to it.

    Args:
        model: The model passed on to ``generate_samples_conditional``.
    """
    args = get_args()
    sample_output_file = args.sample_output_file
    if sample_output_file is None:
        sample_output_file = args.sample_input_file + ".out"
        print('`sample-output-file` not specified, setting '
              'it to {}'.format(sample_output_file))

    with open(sample_output_file, 'w') as out_file:
        for datum in generate_samples_conditional(model):
            if torch.distributed.get_rank() != 0:
                continue
            out_file.write(datum + '\n')


def main():
    """Initialize Megatron, build and load the model, then generate samples.

    Generation only runs when ``args.sample_input_file`` is provided.
    """
    initialize_megatron(
        extra_args_provider=add_text_generate_args,
        args_defaults=dict(no_load_rng=True, no_load_optim=True),
    )

    # Build the model (without DDP wrapping) and optionally restore weights.
    model_chunks = get_model(model_provider, wrap_with_ddp=False)
    print(model_chunks)
    args = get_args()

    if args.load is not None:
        load_checkpoint(model_chunks, None, None)
    model = model_chunks[0]

    # Nothing to generate without an input file.
    if args.sample_input_file is None:
        return

    print(f"{args.sample_input_file}")
    generate_and_write_samples_conditional(model)


# Run generation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
