import atexit
import os
import os
import re
import requests
import subprocess
import sys
import sys
import threading
import time

import torch

from megatron.training import get_args
from megatron.legacy import fused_kernels

try:
    from mpatch.training.priv.wecube import wecube_url
except ImportError:
    from mpatch.training.pub.wecube import wecube_url


def heartbeat(task_name=None, interval=60, request_timeout=10):
    """Start a best-effort background heartbeat reporter on global rank 0.

    On rank 0 a daemon thread collects run metadata: host IP, GPU count,
    command line, git branch/commit and model shape. It POSTs this as JSON to
    ``wecube_url`` every ``interval`` seconds with ``"exit": "false"``. An
    ``atexit`` hook sends one final report with ``"exit": "true"``. Reporting
    errors are swallowed so that telemetry can never break training.

    Args:
        task_name: Optional label added to the payload as ``task_name`` when
            not ``None``.
        interval: Seconds to sleep between two heartbeats (default 60).
        request_timeout: Timeout in seconds for each ``requests.post`` call.
            Without it, an unreachable endpoint could block the reporter
            thread forever. Through the ``atexit`` hook, it could also hang
            interpreter exit.

    Returns:
        None. Returns immediately without reporting on non-zero ranks. It
        also returns immediately when the default ``torch.distributed``
        process group is not initialized, because the rank cannot be
        determined then.
    """
    # Only one process per job reports. get_rank() raises if the default
    # process group is not initialized, so bail out in that case.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return
    if torch.distributed.get_rank() != 0:
        return

    def run_command(command):
        # Commands are fixed literals below (no external input); shell=True is
        # only needed for the pipes. Returns stripped stdout, or " " on failure.
        result = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True
        )
        if result.returncode == 0:
            return result.stdout.strip()
        return " "

    def get_time():
        return int(time.time())

    def get_report_ip():
        # First dotted-quad found in the first `inet` line printed by ifconfig.
        line = run_command("ifconfig | grep inet | head -n 1")
        match = re.search(r"\b(?:\d{1,3}\.){3}\d{1,3}\b", line)
        return match.group() if match is not None else " "

    def get_gpu_num():
        hostfile = "/etc/mpi/hostfile"
        if not os.path.exists(hostfile):
            return 8
        # NOTE(review): assumes 8 GPUs per host line in the MPI hostfile.
        with open(hostfile, "r") as f:
            return len(f.readlines()) * 8

    def get_cmd():
        return " ".join(sys.argv)

    def get_framework():
        return "gcore-dev"

    def get_branch():
        # `git branch` lists every local branch in sorted order. The
        # checked-out one is the line starting with "* ", so select that line
        # rather than whichever branch happens to sort first.
        return run_command("git branch | grep '^[*]'")[2:]

    def get_commit_id():
        # The first matching line of `git log` is "commit <sha>"; drop the prefix.
        return run_command("git log | grep commit | head -n 1")[7:]

    def get_commit_timestamp():
        # The latest "Date: ..." line of `git log`, minus the 6-char "Date: " prefix.
        return run_command("git log | grep Date | head -n 1")[6:]

    def get_model_arch_and_size() -> str:
        args = get_args()
        return f"{args.model_arch}/{args.num_layers}-{args.hidden_size}-{args.ffn_hidden_size}"

    # Identifies this run; taken once, when heartbeat() is called.
    model_id = get_time()

    def background_heartbeat():
        # Static metadata, collected once on the reporter thread.
        k2v = {
            "biz_id": 12331,
            "report_ip": get_report_ip(),
            "model_id": model_id,
            "gpu_num": get_gpu_num(),
            "cmd": get_cmd(),
            "framework": get_framework(),
            "git_branch": get_branch(),
            "commit_id": get_commit_id(),
            "commit_timestamp": get_commit_timestamp(),
            "model_arch_size": get_model_arch_and_size(),
        }
        if task_name is not None:
            k2v["task_name"] = task_name

        def post(exit_flag):
            # Build a fresh payload per request. The atexit hook (main thread)
            # and this daemon loop can run at the same time during shutdown,
            # so the shared dict is never mutated after setup.
            payload = dict(k2v, time=get_time(), exit=exit_flag)
            requests.post(wecube_url, json=payload, timeout=request_timeout)

        def goodbye():
            try:
                post("true")
            except Exception:
                pass  # best-effort: never let telemetry fail interpreter exit

        atexit.register(goodbye)

        while True:
            try:
                post("false")
            except Exception:
                pass  # best-effort: keep heartbeating after transient errors
            time.sleep(interval)

    thread = threading.Thread(target=background_heartbeat, name="heartbeat", daemon=True)
    thread.start()


def initialize_megatron_wrapper(fn):
    """Decorate an initializer so that it starts the heartbeat after running.

    The returned wrapper calls ``fn(*args, **kwargs)``, reads ``px_task_name``
    (default ``None``) from the global args returned by ``get_args()``, and
    passes it to ``heartbeat``. Whatever ``fn`` returns is passed back to the
    caller, and the wrapper keeps ``fn``'s metadata via ``functools.wraps``.

    Args:
        fn: The initialization function to wrap.

    Returns:
        The wrapping function.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)

        # Named `megatron_args` so it does not shadow the positional `*args` tuple.
        megatron_args = get_args()
        task_name = getattr(megatron_args, "px_task_name", None)
        heartbeat(task_name)

        # Propagate fn's return value so callers that depend on it still get it.
        return result

    return wrapper


def _compile_dependencies():
    """Compile the dataset index builder and load the fused kernels.

    Rank 0 compiles the dataset C++ helpers. A warning is printed when the
    optimized fused softmax kernel's constraints are not met. Unless
    ``args.no_fused_kernel`` is set, rank 0 then loads the fused kernels
    first, while every other rank waits on a barrier before loading them.

    This is a collective operation: every rank must call it, because it
    issues ``torch.distributed.barrier()`` on every code path.
    """
    args = get_args()

    # Compile the dataset index builder (C++ helpers) on rank 0 only.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        from megatron.core.datasets.utils import compile_helpers

        compile_helpers()
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(time.time() - start_time),
            flush=True,
        )

    # Check whether the optimized fused softmax kernel can be used; only warn
    # (on rank 0) when it cannot.
    seq_len = args.seq_length
    attn_batch_size = (
        args.num_attention_heads / args.tensor_model_parallel_size
    ) * args.micro_batch_size
    custom_kernel_constraint = (
        seq_len > 16 and seq_len <= 16384 and seq_len % 4 == 0 and attn_batch_size % 4 == 0
    )
    if not ((args.fp16 or args.bf16) and custom_kernel_constraint and args.masked_softmax_fusion):
        if args.rank == 0:
            print(
                "WARNING: constraints for invoking optimized"
                " fused softmax kernel are not met. We default"
                " back to unfused kernel invocations.",
                flush=True,
            )

    if args.no_fused_kernel:
        # Without this barrier, non-zero ranks would return while rank 0 may
        # still be compiling the dataset helpers above. The fused-kernel path
        # below synchronizes all ranks; keep that guarantee here. Every rank
        # reaches this point, so the barrier cannot deadlock.
        torch.distributed.barrier()
        return

    # Load fused kernels on rank 0 first; the other ranks wait, then load
    # (presumably reusing rank 0's build).
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)

    # Final collective barrier so every rank has finished loading before anyone proceeds.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
            flush=True,
        )
