import numpy as np
import torch, os
from functools import partial
from megatron import get_args
try:
    import torch_npu
    _IS_NPU_AVAILABLE = True
except ImportError:
    _IS_NPU_AVAILABLE = False
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import get_global_rank_moe_loss
from megatron.core import tensor_parallel
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.model import GPTModel, MedusaModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import get_cu_seqlens
from megatron.utils import round_multiple
from tools import tracking_utils
from megatron.core.utils import torch_version
from megatron import get_micro_batch_id, get_num_microbatches


# Running sum of per-micro-batch losses used by loss_func when
# args.avg_loss_comm_optimize is enabled. It is lazily allocated on first use
# and zeroed again once the micro-batch id equals the number of micro-batches.
GLOBAL_LOSS = None


def model_provider(pre_process=True, post_process=True):
    """Instantiate the language model for this pipeline stage.

    ``MedusaModel`` is used when ``args.medusa_num_heads`` is set; otherwise a
    plain ``GPTModel`` is built. ``pre_process`` and ``post_process`` are
    forwarded unchanged to the model constructor.
    """
    args = get_args()

    print_rank_0('building GPT model ...')
    if args.medusa_num_heads is None:
        model_class = GPTModel
    else:
        model_class = MedusaModel

    return model_class(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process,
    )


def _prepare_varlen_inputs(args, tokenizer, tokens, labels):
    """Compute cumulative sequence lengths and optionally truncate the inputs.

    Returns ``(tokens, labels, cu_seqlens, cu_seqlens_host)`` where
    ``cu_seqlens_host`` is a CPU copy of ``cu_seqlens``. When
    ``args.varlen_end_to_end`` is set, ``tokens`` and ``labels`` are sliced
    along dim 1 to ``round_multiple(cu_seqlens[1], 128)`` columns.
    """
    cu_seqlens = get_cu_seqlens(data=tokens, eod_token=tokenizer.eod, pad_token=tokenizer.pad)
    cu_seqlens_host = cu_seqlens.to("cpu")
    if args.varlen_end_to_end:
        round_seqlen = round_multiple(cu_seqlens[1].item(), 128)
        tokens = tokens[:, :round_seqlen]
        labels = labels[:, :round_seqlen]
    return tokens, labels, cu_seqlens, cu_seqlens_host


def _build_masks_and_sample_count(args, tokenizer, tokens, labels, cu_seqlens):
    """Build attention/loss masks and position ids, plus the sample count.

    Returns ``(attention_mask, loss_mask, position_ids, sample_count)``.
    ``sample_count`` is 0 unless ``args.splicing_weight`` is set, in which
    case the loss mask is flattened and re-weighted by ``weight_loss_mask``.
    """
    # megatron original mask
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        bos_mask_loss=args.bos_mask_loss,
        bos_token=tokenizer.bos,
        skip_attention_mask=(args.use_flash_attn and args.varlen_attention) if _IS_NPU_AVAILABLE else args.use_flash_attn
    )
    # mask the inputs of labels for (input, target) format, see megatron.data.chat_dataset.ChatDataset
    loss_mask[labels < 0] = 0.0
    # mask the pad tokens for pad_token in pretraining
    if args.dataset_type == "pretrain_processed":
        loss_mask[labels == tokenizer.pad] = 0.0
    sample_count = 0
    if args.splicing_weight:
        loss_mask = loss_mask.view(-1).float()
        loss_mask, sample_count = weight_loss_mask(loss_mask, cu_seqlens)
    return attention_mask, loss_mask, position_ids, sample_count


def get_batch(data_iterator, p2p_tokens=None):
    """Generate a batch.

    Args:
        data_iterator: iterator yielding dicts with ``"plain_text"`` and
            ``"label"`` entries, or ``None`` on ranks that only receive data.
        p2p_tokens: tokens handed over from the previous pipeline stage; only
            consulted on intermediate stages when
            ``args.reduce_pipeline_data_io`` is enabled.

    Returns:
        A 9-tuple ``(tokens, labels, loss_mask, attention_mask, position_ids,
        cu_seqlens, sample_count, cu_seqlens_host, p2p_tokens)``. On
        intermediate pipeline stages with ``args.reduce_pipeline_data_io`` the
        first five entries are always ``None``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    keys = ["plain_text", "label"]
    datatype = torch.int64

    # TODO: this is pretty hacky, find a better way
    if args.reduce_pipeline_data_io and (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        # Intermediate stages do not read from the data iterator at all.
        if not (args.varlen_end_to_end or args.varlen_attention) or p2p_tokens is None:
            return None, None, None, None, None, None, 0, None, p2p_tokens

        # p2p_tokens carries one extra trailing column so that inputs and
        # shifted labels can both be recovered from it.
        tokens = p2p_tokens[:, :-1].contiguous()
        labels = p2p_tokens[:, 1:].contiguous()

        tokens, labels, cu_seqlens, cu_seqlens_host = _prepare_varlen_inputs(
            args, tokenizer, tokens, labels)
        # NOTE(review): only sample_count is consumed from this call on an
        # intermediate stage; the masks are computed and then dropped, exactly
        # as before. Confirm the mask helper is side-effect free before
        # skipping it when args.splicing_weight is off.
        _, _, _, sample_count = _build_masks_and_sample_count(
            args, tokenizer, tokens, labels, cu_seqlens)
        return None, None, None, None, None, cu_seqlens, sample_count, cu_seqlens_host, p2p_tokens

    if args.batch_data_scatter:
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        if mpu.get_data_parallel_rank() == 0:
            # Broadcast data.
            if data_iterator is not None:
                data = next(data_iterator)
            else:
                data = None
            data_b = tensor_parallel.broadcast_data(keys, data, datatype)
            data_index = list(range(0, data_parallel_world_size))
            # Inputs and labels are stacked along dim 0, then strided so that
            # each data-parallel rank receives its own rows of both halves.
            send_tensor = torch.cat((data_b['plain_text'], data_b['label']))
            send_tensor_list = [send_tensor[i:2 * args.micro_batch_size * data_parallel_world_size:data_parallel_world_size].contiguous() for i in data_index]
        else:
            send_tensor_list = None
        recv_tensor = torch.zeros(2 * args.micro_batch_size, args.seq_length, device=torch.cuda.current_device(), dtype=datatype)
        # Scatter data.
        torch.distributed.scatter(
            recv_tensor,
            send_tensor_list,
            list(mpu._DATA_PARALLEL_GLOBAL_RANKS)[0],
            mpu.get_data_parallel_group(),
        )
        data_b = {}
        data_b['plain_text'] = recv_tensor[:args.micro_batch_size]
        data_b['label'] = recv_tensor[args.micro_batch_size:]
    else:
        # Broadcast data.
        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None
        data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    cu_seqlens = None
    cu_seqlens_host = None
    p2p_tokens = None

    # Unpack. Only the ("plain_text", "label") format is requested via `keys`,
    # so the former single-"text" code path could never run and was removed.
    labels = data_b["label"].long().contiguous()
    tokens = data_b["plain_text"].long().contiguous()
    if args.reduce_pipeline_data_io \
            and mpu.is_pipeline_first_stage() \
            and (args.varlen_end_to_end or args.varlen_attention):
        # Append the final label column so downstream stages can rebuild both
        # tokens and labels from a single tensor.
        p2p_tokens = torch.cat((tokens, labels[:, -1].view([args.micro_batch_size, 1])), 1)

    # Pre-compute cu_seqlens, if exact_varlen is enabled, directly truncate input sequence
    if args.varlen_end_to_end or args.varlen_attention:
        tokens, labels, cu_seqlens, cu_seqlens_host = _prepare_varlen_inputs(
            args, tokenizer, tokens, labels)

    attention_mask, loss_mask, position_ids, sample_count = _build_masks_and_sample_count(
        args, tokenizer, tokens, labels, cu_seqlens)

    return tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, sample_count, cu_seqlens_host, p2p_tokens

def loss_func(loss_mask, output_tensor, moe_loss=None, tokens=None):
    """Compute the masked LM loss and the metrics dict reported with it.

    Args:
        loss_mask: mask applied element-wise to the flattened losses.
        output_tensor: per-token losses produced by the model.
        moe_loss: optional value reported under ``'moe_loss'``.
        tokens: optional payload reported under ``'tokens'`` when
            ``args.record_spike_loss`` is set.

    Returns:
        ``(loss, metrics)``. Outside evaluation the loss is divided by the
        mask sum (when non-zero); during evaluation it stays a plain sum and
        the all-reduced mask sum is reported as ``'num_eval_tokens'``.
    """
    global GLOBAL_LOSS

    args = get_args()
    flat_mask = loss_mask.view(-1).float()
    all_loss = output_tensor.float().view(-1) * flat_mask
    loss = all_loss.sum()

    # Check individual rank losses are not NaN prior to DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = torch.distributed.get_rank()
        assert not loss.isnan(), (
            f'Rank {global_rank}: found NaN in local forward loss calculation. '
            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
        )

    mask_total = flat_mask.sum()
    if mask_total != 0 and not args.evaluating:
        loss = loss / mask_total

    # Reduce loss for logging.
    if args.avg_loss_comm_optimize and not args.evaluating:
        # Accumulate locally and only expose the running total once the
        # micro-batch id reaches the number of micro-batches.
        if GLOBAL_LOSS is None:
            GLOBAL_LOSS = torch.zeros(1, device=torch.cuda.current_device(), dtype=loss.dtype)
        GLOBAL_LOSS += loss.clone().detach()
        if get_micro_batch_id() == get_num_microbatches():
            averaged_loss = [GLOBAL_LOSS.detach().clone()]
            GLOBAL_LOSS.zero_()
        else:
            averaged_loss = [loss.clone().detach()]
    else:
        averaged_loss = average_losses_across_data_parallel_group([loss])

    metrics = {'lm loss': averaged_loss[0]}

    if args.evaluating:
        num_eval_tokens = torch.cuda.FloatTensor([mask_total])
        torch.distributed.all_reduce(num_eval_tokens,
                                     group=mpu.get_data_parallel_group())
        if moe_loss is not None:
            metrics['moe_loss'] = moe_loss
        metrics['num_eval_tokens'] = num_eval_tokens[0]
        return loss, metrics

    if moe_loss is not None:
        metrics['moe_loss'] = moe_loss
    if args.record_spike_loss:
        metrics['all_loss'] = all_loss.detach().to("cpu")
        metrics['tokens'] = tokens
    return loss, metrics

def loss_func_splicing_weight(loss_mask, output_tensor, moe_loss=None):
    """Sum the losses weighted by ``loss_mask`` and build the metrics dict.

    ``loss_mask`` is multiplied with the flattened losses as-is (it is not
    reshaped here). Returns ``(loss, metrics)`` where ``metrics`` holds the
    data-parallel averaged loss under ``'lm loss'`` and, if given,
    ``moe_loss`` under ``'moe_loss'``.
    """
    weighted_losses = output_tensor.float().view(-1) * loss_mask
    loss = weighted_losses.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    metrics = {'lm loss': averaged_loss[0]}
    if moe_loss is not None:
        metrics['moe_loss'] = moe_loss
    return loss, metrics


def weight_loss_mask(loss_mask, cu_seqlens):
    """Normalize a flat loss mask independently within each sequence segment.

    Args:
        loss_mask: 1-D mask indexed by the boundaries in ``cu_seqlens``.
        cu_seqlens: segment boundaries; consecutive entries ``[b_i, b_{i+1})``
            delimit one segment. May be a tensor or a plain sequence.

    Returns:
        ``(weighted_mask, sample_count)``. Inside every segment whose mask sum
        is non-zero the mask is divided by that sum; all other positions are
        zero. ``sample_count`` is the number of such non-empty segments.
    """
    # Convert the boundaries once instead of turning a tensor element into a
    # Python index for every slice bound.
    if torch.is_tensor(cu_seqlens):
        boundaries = cu_seqlens.tolist()
    else:
        boundaries = list(cu_seqlens)

    sample_count = 0
    weighted_mask = torch.zeros_like(loss_mask)
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        segment = loss_mask[start:end]
        segment_total = torch.sum(segment)
        if segment_total != 0:
            sample_count += 1
            weighted_mask[start:end] = segment / segment_total
    return weighted_mask, sample_count


def forward_step(data_iterator, model, optimizer=None, p2p_tokens=None):
    """Forward step.

    Fetches a batch via ``get_batch``, runs the model and returns the pieces
    the training schedule needs to finish the step.

    Args:
        data_iterator: iterator passed through to ``get_batch``.
        model: callable invoked with tokens, position ids, attention mask and
            the ``labels`` / ``cu_seqlens`` / ``cu_seqlens_host`` keywords.
        optimizer: unused here; kept for interface compatibility.
        p2p_tokens: tokens received from the previous pipeline stage, passed
            through to ``get_batch``.

    Returns:
        ``(output_tensor, loss_fn, sample_num, p2p_tokens)`` where ``loss_fn``
        is a ``functools.partial`` already bound to the loss mask and
        ``sample_num`` is a one-element tensor holding the sample count
        (all-reduced over the data-parallel group when
        ``args.splicing_weight`` is set).
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    (tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens,
     sample_count, cu_seqlens_host, _p2p_tokens) = get_batch(data_iterator, p2p_tokens)
    timers('batch-generator').stop()

    # get_batch can return both tokens and p2p tokens as None (intermediate
    # pipeline stage without incoming p2p tokens); fall back to the current
    # device instead of dereferencing None.
    if tokens is not None:
        sample_device = tokens.device
    elif _p2p_tokens is not None:
        sample_device = _p2p_tokens.device
    else:
        sample_device = torch.cuda.current_device()
    sample_num = torch.tensor([sample_count], device=sample_device)
    if args.splicing_weight:
        torch.distributed.all_reduce(sample_num, group=mpu.get_data_parallel_group())

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels, cu_seqlens=cu_seqlens, cu_seqlens_host=cu_seqlens_host)

    if args.splicing_weight:
        return output_tensor, partial(loss_func_splicing_weight, loss_mask), sample_num, _p2p_tokens
    if args.record_spike_loss:
        # Keep CPU copies of the inputs so loss_func can report them alongside
        # the per-token losses.
        if tokens is None:
            data = {'inputs': None, 'labels': None}
        else:
            data = {'inputs': tokens.to("cpu"), 'labels': labels.to("cpu")}
        return output_tensor, partial(loss_func, loss_mask, tokens=data), sample_num, _p2p_tokens
    return output_tensor, partial(loss_func, loss_mask), sample_num, _p2p_tokens


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Create the train/validation/test datasets and register their paths.

    The dataset builder is chosen from ``args.dataset_type``
    (``"pretrain_processed"`` or ``"chat"``). Every configured data path is
    afterwards reported through ``tracking_utils.track_artifact_rank_0``.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` as produced by the selected builder.
    """
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')

    if args.dataset_type != "pretrain_processed":
        assert args.dataset_type == "chat"
        from megatron.data.chat_dataset import build_train_valid_test_datasets
    else:
        from megatron.data.pretrain_processed_dataset import build_train_valid_test_datasets

    builder_kwargs = dict(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path,
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**builder_kwargs)
    print_rank_0("> finished creating GPT datasets ...")

    # Record each configured dataset path as a tracked artifact.
    path_groups = (args.data_path, args.train_data_path, args.valid_data_path, args.test_data_path)
    for path_group in path_groups:
        if path_group is not None:
            for path in path_group:
                tracking_utils.track_artifact_rank_0(path, "dataset", [path], [])

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    # cuBLAS needs a fixed workspace configuration for deterministic results.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    if _IS_NPU_AVAILABLE:
        # On NPU, determinism is opt-in through the DETERMINISTIC_MODE env var.
        torch.use_deterministic_algorithms(
            os.environ.get("DETERMINISTIC_MODE", "False") == "True")
    else:
        torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # The fill_uninitialized_memory switch is only touched on torch >= 2.2.
    # Compare as a tuple so that a future major version with a small minor
    # (e.g. 3.0) is still recognised as newer than 2.2.
    torch_major, torch_minor = torch_version()
    if (torch_major, torch_minor) >= (2, 2):
        torch.utils.deterministic.fill_uninitialized_memory = False

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
    )
