

"""Pretrain GeBERT for generation tasks"""
from deepspeed.accelerator import get_accelerator
from megatron.initialize import initialize_megatron
from functools import partial
import jsonlines
import json_lines
import torch
import torch.nn.functional as F
from megatron.checkpointing import load_checkpoint
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gebert_utils import build_train_valid_test_datasets
from megatron.model import GeBertModel
from megatron.utils import average_losses_across_data_parallel_group
import argparse
from megatron.arguments import core_transformer_config_from_args
from megatron.training import get_model
from megatron import get_tokenizer
import numpy as np
from fairseq.utils import new_arange
from megatron.initialize import set_jit_fusion_options
# import os 
# os.environ["MASTER_PORT"] = "6666"
from tqdm import tqdm

def args_provider(parser):
    """Register the script-specific command-line options on ``parser``.

    All options are added to an argument group titled 'Extra args'.
    The same parser object that was passed in is returned, so this function
    can be handed to Megatron as an ``extra_args_provider``.
    """
    extra = parser.add_argument_group(title='Extra args')

    # (flag, keyword arguments for add_argument), kept in registration order
    # so that the generated --help output is unchanged.
    option_specs = (
        ('--sentence-split-type',
         dict(type=str, default='direct', help='How to split the source and target.')),
        ('--masked-type',
         dict(type=str, default='lazy_uniform', help='The masking type.')),
        ('--outfile',
         dict(type=str, default='', help='The output file.')),
        ('--inputfile',
         dict(type=str, default='', help='input file')),
        ('--extra-outfile',
         dict(type=str, default=None, help='extra output file.')),
        ('--length-predict',
         dict(action='store_true', help='if use length prediction.')),
        ('--max-predict-length',
         dict(type=int, default=2048, help='The max predict length.')),
        ('--max-iter',
         dict(type=int, default=1, help='The decoding steps.')),
        ('--length-beam',
         dict(type=int, default=1, help='The length beam.')),
        ('--load-LP-module',
         dict(action='store_true', help='if we adopt length prediction')),
    )
    for flag, options in option_specs:
        extra.add_argument(flag, **options)

    return parser


def model_provider(pre_process=True, post_process=False):
    """Construct a ``GeBertModel`` from the global Megatron arguments.

    Args:
        pre_process: forwarded unchanged to ``GeBertModel``.
        post_process: accepted for signature compatibility only; the model is
            always built with ``post_process=True`` regardless of this value.

    Returns:
        The newly built ``GeBertModel`` instance.
    """
    print_rank_0('building GEBERT model ...')
    transformer_config = core_transformer_config_from_args(get_args())

    # NOTE(review): post_process is hard-coded to True here and the argument
    # is ignored -- confirm this is intended.
    return GeBertModel(
        config=transformer_config,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=True,
    )

def inferece_LP_length_beam(model, batch_sample, tokenizer, max_seq_length=1023, max_iter=10, device=None, length_beam=1):
    """Predict target lengths for a batch and pair them with the gold lengths.

    Each source text is tokenized, truncated to ``max_seq_length`` and
    right-padded with ``tokenizer.pad`` to exactly ``max_seq_length`` tokens.
    The padded batch is run through ``model``; element 1 of its output is
    passed to ``model.forward_length`` to obtain per-sample length scores.
    The indices of the ``length_beam`` highest scores along the last axis are
    used as the candidate target lengths.

    Args:
        model: callable as ``model(source_ids, source_padding_mask)`` and
            exposing ``forward_length(hidden, source_padding_mask)``.
        batch_sample: sequence of dicts with ``"source"`` and ``"target"``
            text entries.
        tokenizer: object with ``tokenize(text) -> list of token ids`` and a
            ``pad`` token id.
        max_seq_length: truncation length and padded width of the sources;
            the targets are truncated to it as well before being counted.
        max_iter: unused; kept so existing callers keep working.
        device: device the source batch is created on (``None`` = default).
        length_beam: number of top-scoring length candidates per sample.

    Returns:
        A tuple ``(gold_lengths, mean_predicted, top1_predicted)`` of three
        lists with one entry per sample: the (truncated) token count of the
        target, the mean of the ``length_beam`` candidate lengths, and the
        best-scoring candidate length.
    """
    # Nothing to run: avoid building a malformed (empty) tensor for the model.
    if not batch_sample:
        return [], [], []

    padded_sources = []
    result_GL = []
    for sample in batch_sample:
        source_ids = tokenizer.tokenize(sample["source"])[:max_seq_length]
        target_ids = tokenizer.tokenize(sample["target"])[:max_seq_length]
        padding = [tokenizer.pad] * (max_seq_length - len(source_ids))
        padded_sources.append(source_ids + padding)
        result_GL.append(len(target_ids))

    batch_source = torch.tensor(padded_sources, dtype=torch.long, device=device)
    # 1 where a real token is present, 0 on padding positions.
    source_padding_mask = batch_source.ne(tokenizer.pad).type_as(batch_source)

    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        output_tensors = model(batch_source, source_padding_mask)
        length_out = model.forward_length(output_tensors[1], source_padding_mask).detach()

    # topk(...)[1] holds the indices, i.e. the candidate lengths, best first.
    length_tgt = length_out.topk(length_beam, dim=-1, largest=True, sorted=True)[1].tolist()

    result_LP = [sum(candidates) / length_beam for candidates in length_tgt]
    result_LP_one = [candidates[0] for candidates in length_tgt]

    return result_GL, result_LP, result_LP_one


import deepspeed
if __name__ == "__main__":
    # DeepSpeed engine configuration, used only when --deepspeed is set.
    # The engine is created without an optimizer (inference only).
    ds_dict = {
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 10,
        "zero_optimization": {
            "stage": 0,
        },
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 11,
        },
        "bf16": {
            "enabled": False,
        },
        "wall_clock_breakdown": False,
    }

    device = torch.device("cuda")
    initialize_megatron(extra_args_provider=args_provider, args_defaults={'tokenizer_type': 'HFTokenizer'})
    args = get_args()

    # Fail fast with a clear message instead of loading the whole model first
    # and then crashing on open('').
    if not args.inputfile:
        raise ValueError("--inputfile is required")

    tokenizer = get_tokenizer()
    model = get_model(model_provider)
    if args.deepspeed:
        if get_accelerator().device_name() == 'cuda':
            set_jit_fusion_options()
        args.deepspeed_config_dict = ds_dict
        model, _, _, _ = deepspeed.initialize(
            model=model[0],
            optimizer=None,
            args=args,
            lr_scheduler=None,
            mpu=None,
            config=args.deepspeed_config_dict,
        )
        load_checkpoint([model], None, None, strict=True, load_only_weights=False)
    else:
        load_checkpoint(model, None, None, strict=False, load_only_weights=False)
        assert len(model) == 1, "Above condition should have caught this"
        model = model[0].to(device)

    model.eval()

    # Input is a JSON-lines file; each record must provide "source" and
    # "target" (read by inferece_LP_length_beam).
    with open(args.inputfile, 'rb') as f:
        item_list = list(json_lines.reader(f))

    max_seq_length = args.seq_length
    max_iter = args.max_iter
    batch_size = args.micro_batch_size

    final_result_GL = []      # gold target lengths
    final_result_LP = []      # mean of the length-beam candidates
    final_result_LP_one = []  # top-1 predicted length

    # Inference only: no autograd bookkeeping is needed for any batch.
    with torch.no_grad():
        for index in tqdm(range(0, len(item_list), batch_size)):
            batch_sample = item_list[index:index + batch_size]
            result_GL, result_LP, result_LP_one = inferece_LP_length_beam(
                model, batch_sample, tokenizer, max_seq_length, max_iter, device, args.length_beam)
            final_result_GL.extend(result_GL)
            final_result_LP.extend(result_LP)
            final_result_LP_one.extend(result_LP_one)

    if args.outfile:
        # Persist one JSON object per input sample. Only one process writes,
        # so ranks that run concurrently do not clobber the same file.
        import json

        distributed_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
        if not distributed_ready or torch.distributed.get_rank() == 0:
            with open(args.outfile, 'w', encoding='utf-8') as out:
                for gold, mean_pred, top1_pred in zip(final_result_GL, final_result_LP, final_result_LP_one):
                    record = {
                        "gold_length": gold,
                        "predicted_length_mean": mean_pred,
                        "predicted_length_top1": top1_pred,
                    }
                    out.write(json.dumps(record) + "\n")
        print_rank_0(f"wrote {len(final_result_GL)} length predictions to {args.outfile}")
    else:
        # Legacy workflow: with no --outfile, drop into the debugger so the
        # final_result_* lists can be inspected interactively.
        # NOTE: this blocks non-interactive / multi-rank launches; pass
        # --outfile to avoid it.
        import pdb; pdb.set_trace()

