# coding=utf-8
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for initializing Megatron: compiling its dataset helpers and loading its fused kernels.
"""
import os
import time
import torch
from .fused_kernels import load as fused_kernels_load


def compile_megatron_dependencies(args, master_port="29501"):
    """
    Compile Megatron's dataset C++ helpers and load its fused kernels.

    If ``torch.distributed`` is not initialized yet, a single-process group
    (rank 0, world size 1, ``localhost``) is created first, because the kernel
    loading below synchronizes ranks with ``torch.distributed.barrier()``.

    args need to have the following attr:
        max_length: maximum sequence length (used for the fused-softmax
            constraint check).
        num_attention_heads: number of attention heads.
        fp16: whether fp16 is enabled.
        bf16: whether bf16 is enabled.
    and may optionally have (defaults in parentheses):
        tensor_model_parallel_size: (1)
        micro_batch_size: (1)
        masked_softmax_fusion: (False)
        rank: (``get_rank()``) -- set on ``args`` when missing.

    Args:
        args: namespace-like configuration object described above; it is also
            forwarded unchanged to ``fused_kernels_load``.
        master_port: rendezvous port used only when this function has to
            initialize the process group itself. Converted with ``str()``, so
            an int is accepted as well.
    """
    gpu_available = torch.cuda.device_count() > 0

    # If distributed is not initialized, manually initialize it: the barriers
    # below require a process group.
    if not torch.distributed.is_initialized():
        os.environ['RANK'] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "localhost"
        # os.environ only accepts str values.
        os.environ["MASTER_PORT"] = str(master_port)
        # NCCL needs GPUs; gloo lets CPU-only hosts still create the group.
        backend = "nccl" if gpu_available else "gloo"
        torch.distributed.init_process_group(backend=backend)

    # check args
    if not hasattr(args, "rank"):
        setattr(args, "rank", get_rank())

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if args.rank == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    max_seq_len = args.max_length
    masked_softmax_fusion = getattr(args, "masked_softmax_fusion", False)
    attn_batch_size = \
        (args.num_attention_heads / getattr(args, "tensor_model_parallel_size", 1)) * \
        getattr(args, "micro_batch_size", 1)
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = 16 < max_seq_len <= 4096 and \
        max_seq_len % 4 == 0 and attn_batch_size % 4 == 0
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)
            print("Here are the details: fp16 {}, bf16 {}, constraint {}, fusion {}.".format(
                args.fp16, args.bf16, [max_seq_len, attn_batch_size],
                masked_softmax_fusion), flush=True)

    # Always build on rank zero first; the other ranks wait at the barrier and
    # then load the already-built kernels. Loading is skipped on CPU-only
    # hosts on every rank, not just rank zero.
    if args.rank == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        if gpu_available:
            fused_kernels_load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        if gpu_available:
            fused_kernels_load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if args.rank == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)


def get_rank():
    """
    Return the rank this process should use for one-time setup work.

    Returns 0 when ``torch.distributed`` is not initialized. Otherwise returns
    ``torch.distributed.get_rank()``, except on AzureML (detected by the
    ``AZUREML_EXPERIMENT_ID`` environment variable), where any global rank
    that is a multiple of the local GPU count is reported as 0.

    NOTE(review): the AzureML rule presumably makes the first process on each
    node act as rank 0 (assumes ranks are laid out contiguously per node with
    one process per GPU) -- confirm.
    """
    if not torch.distributed.is_initialized():
        return 0
    rank = torch.distributed.get_rank()
    num_gpus = torch.cuda.device_count()
    # Guard the modulo: device_count() is 0 on CPU-only hosts.
    if 'AZUREML_EXPERIMENT_ID' in os.environ and num_gpus > 0 and rank % num_gpus == 0:
        return 0
    return rank
