                                                                                                          
                                                                                                            
                                                      
                                                                                        

from contextlib import asynccontextmanager
from datetime import timedelta
import sys
import time
import os
import json
import threading
import subprocess
import random
import psutil
import datetime
import signal
import traceback
import multiprocessing as mp

from fastapi import FastAPI, HTTPException
from fastapi import Request
from fastapi.responses import JSONResponse, Response
from torch.distributed import TCPStore, get_rank, DistNetworkError
import fastapi
import torch
import uvicorn
import httpx
import requests

from gpatch.core.utils import print_with_rank_and_datetime

MONITOR_INTERVAL = 30
REQ_TIMEOUT = 5
MAX_RETRIES = 5


def kill_all():
    """Send SIGKILL to every other Python process, then exit this one.

    Every process reported by ``psutil.process_iter`` whose process name
    contains ``python`` (case-insensitive) is killed, except the calling
    process itself. A failure to kill one process is logged and does not
    stop the sweep.

    Raises:
        SystemExit: always, via ``sys.exit(0)`` once the sweep is finished.
    """
    print('monitor kill_all', flush=True)
    own_pid = os.getpid()
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        # psutil substitutes None for attributes it could not read, so guard
        # against a missing name instead of calling .lower() on None.
        pname = proc.info['name'] or ''
        if 'python' not in pname.lower():
            continue

        pid = proc.info['pid']
        if pid == own_pid:
            continue

        cmdline = proc.info['cmdline']
        print(f'monitor kill_all kill {pid} {pname} {cmdline}', flush=True)
        try:
            os.kill(pid, signal.SIGKILL)
        except OSError as err:
            # The target may already be gone or may not be ours to signal;
            # keep going so the remaining processes are still killed.
            print(f'monitor kill_all failed to kill {pid}: {err}', flush=True)
    sys.exit(0)


def _start_server_loop(monitor_server_ip, port):
    """Build the monitor FastAPI app and serve it with uvicorn.

    The app keeps two in-memory values: an exit flag (initially False) and
    the time of the last reported progress (initially the time the app was
    built). It exposes three POST endpoints:

    * ``/set_exit`` parses the JSON body, then sets the exit flag.
    * ``/mark_made_progress`` stores the server's current ``time.time()``.
    * ``/check`` returns both values as JSON.

    Args:
        monitor_server_ip: host passed to ``uvicorn.run``.
        port: port passed to ``uvicorn.run``.
    """
    print(f"monitor server start {monitor_server_ip=} {port=}", flush=True)

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # No startup or shutdown work is performed.
        yield

    app = fastapi.FastAPI(lifespan=lifespan)

    # Shared by the handlers below. Key order matters: /check serialises this
    # dict as-is.
    state = {"exit": False, 'last_time_progress': time.time()}

    @app.post("/set_exit")
    async def set_exit(request: Request):
        await request.json()
        # Log only on the first transition to the exit state.
        if not state["exit"]:
            print("monitor set_exit", flush=True)
        state["exit"] = True
        return JSONResponse(content={"ret": "ok"})

    @app.post("/mark_made_progress")
    async def mark_made_progress(request: Request):
        state['last_time_progress'] = time.time()
        return JSONResponse(content={"ret": "ok"})

    @app.post("/check")
    async def check(request: Request):
        return JSONResponse(content=dict(state))

    server_options = {
        'host': monitor_server_ip,
        'port': port,
        'log_level': 'error',
        'use_colors': False,
        'timeout_keep_alive': 60,
        'ssl_keyfile': None,
        'ssl_certfile': None,
        'ssl_ca_certs': None,
        'ssl_cert_reqs': None,
    }
    uvicorn.run(app, **server_options)


def start_server_loop(monitor_server_ip, port):
    """Double-fork into a new session and run the monitor server there.

    The calling process and the intermediate child both leave through
    ``sys.exit(0)``; only the grandchild, created after ``os.setsid()``,
    goes on to call ``_start_server_loop``.

    Args:
        monitor_server_ip: host the server binds to.
        port: port the server binds to.
    """
    if os.fork() > 0:
        sys.exit(0)
    os.setsid()
    if os.fork() > 0:
        sys.exit(0)

    try:
        _start_server_loop(monitor_server_ip, port)
    except Exception:
        # Only real errors are reported; SystemExit and KeyboardInterrupt
        # propagate unchanged so an exit request keeps its own status.
        traceback.print_exc()


def start_monitor_server_in_background(monitor_server_ip, port):
    """Start ``start_server_loop`` in a new 'spawn'-context process.

    The process is started and not joined; nothing is returned.
    """
    spawn_ctx = mp.get_context('spawn')
    server_proc = spawn_ctx.Process(
        target=start_server_loop,
        args=(monitor_server_ip, port),
    )
    server_proc.start()


def mark_made_progress(monitor_server_ip, port):
    """Report progress to the monitor server (only on distributed rank 0).

    POSTs to ``/mark_made_progress``, retrying up to ``MAX_RETRIES`` times
    with a one-second pause between attempts. This is best effort: if every
    attempt fails the function returns without raising.

    Args:
        monitor_server_ip: host of the monitor server.
        port: port of the monitor server.
    """
    if torch.distributed.get_rank() != 0:
        return

    url = f'http://{monitor_server_ip}:{port}/mark_made_progress'
    for try_i in range(MAX_RETRIES):
        try:
            resp = requests.post(url, json={}, timeout=REQ_TIMEOUT)
            resp.raise_for_status()
            return
        except Exception:
            # Deliberately broad (any request failure is retried), but not a
            # bare except: KeyboardInterrupt/SystemExit must still be able to
            # stop the calling process.
            print(f'monitor mark_made_progress dist network err {try_i}', flush=True)
            time.sleep(1)


def _start_client_loop(role, rank, monitor_server_ip, port, max_time_wo_progress):
    """Poll the monitor server forever and call ``kill_all`` when required.

    Every ``MONITOR_INTERVAL`` seconds the loop POSTs to ``/check``.
    ``kill_all`` is called when any of the following holds:

    * ``MAX_RETRIES`` consecutive polls have failed;
    * the server reports ``exit`` as true;
    * the reported ``last_time_progress`` is more than
      ``max_time_wo_progress`` seconds old. In this case the exit flag is
      first set on the server through ``set_exit_flag``.

    Args:
        role: accepted for interface compatibility; not used here.
        rank: rank of the monitored process, used in log messages only.
        monitor_server_ip: host of the monitor server.
        port: port of the monitor server.
        max_time_wo_progress: maximum tolerated seconds without progress.
    """
    print(f"monitor client start {rank=} {monitor_server_ip=} {port=}", flush=True)
    url = f'http://{monitor_server_ip}:{port}/check'
    e_cnt = 0

    while True:
        try:
            resp = requests.post(url, json={}, timeout=REQ_TIMEOUT)
            resp.raise_for_status()
            resp_json = resp.json()
        except Exception:
            # Any poll failure counts as one strike; SystemExit and
            # KeyboardInterrupt are intentionally not swallowed.
            e_cnt += 1
            print(f'monitor client dist network err {e_cnt}', flush=True)
            if e_cnt < MAX_RETRIES:
                time.sleep(MONITOR_INTERVAL)
                continue
            # Server unreachable for MAX_RETRIES polls in a row.
            kill_all()
        else:
            e_cnt = 0

        # Exit was requested on the server.
        if resp_json['exit']:
            # Flushed so the message is not lost in a buffer during shutdown.
            print(
                f"monitor client detected kill signal {rank=} "
                f"start exiting in {2 * MONITOR_INTERVAL} seconds",
                flush=True,
            )
            # Presumably a grace period so other clients can also poll and
            # observe the flag before processes start dying - TODO confirm.
            time.sleep(2 * MONITOR_INTERVAL)
            kill_all()

        # Stall detection.
        # NOTE(review): last_time_progress is the server's time.time() while
        # `now` is taken locally; if the two run on different hosts, clock
        # skew enters this comparison.
        last_time_progress = resp_json['last_time_progress']
        now = time.time()
        if now - last_time_progress > max_time_wo_progress:
            # Flushed because set_exit_flag may SIGKILL this process, which
            # would drop any buffered output.
            print(f'monitor no progress for {now - last_time_progress} seconds', flush=True)
            set_exit_flag(monitor_server_ip, port)
            time.sleep(2 * MONITOR_INTERVAL)
            kill_all()

        time.sleep(MONITOR_INTERVAL)


def start_client_loop(role, rank, monitor_server_ip, port, max_time_wo_progress):
    """Double-fork into a new session and run the monitor client there.

    The calling process and the intermediate child both leave through
    ``sys.exit(0)``; only the grandchild, created after ``os.setsid()``,
    goes on to call ``_start_client_loop`` with the given arguments.
    """
    if os.fork() > 0:
        sys.exit(0)
    os.setsid()
    if os.fork() > 0:
        sys.exit(0)

    try:
        _start_client_loop(role, rank, monitor_server_ip, port, max_time_wo_progress)
    except Exception:
        # Report unexpected errors only. The client loop normally ends with
        # kill_all() -> sys.exit(0); catching that SystemExit here would print
        # a misleading traceback on every regular shutdown.
        traceback.print_exc()


def start_monitor_client_in_background(role, monitor_server_ip, port, max_time_wo_progress):
    """Start ``start_client_loop`` in a new 'spawn'-context process.

    The caller's distributed rank is read here and passed along with the
    other arguments. The process is started and not joined.
    """
    spawn_ctx = mp.get_context('spawn')
    client_args = (
        role,
        torch.distributed.get_rank(),
        monitor_server_ip,
        port,
        max_time_wo_progress,
    )
    client_proc = spawn_ctx.Process(target=start_client_loop, args=client_args)
    client_proc.start()


def set_exit_flag(monitor_server_ip, port):
    """Ask the monitor server to set its exit flag.

    POSTs to ``/set_exit``, retrying up to ``MAX_RETRIES`` times with a
    one-second pause between attempts, and returns on the first success.
    If every attempt fails, the calling process sends SIGKILL to itself.

    Args:
        monitor_server_ip: host of the monitor server.
        port: port of the monitor server.
    """
    print("monitor client sending kill signal to monitor server", flush=True)
    url = f'http://{monitor_server_ip}:{port}/set_exit'
    for try_i in range(MAX_RETRIES):
        try:
            resp = requests.post(url, json={}, timeout=REQ_TIMEOUT)
            resp.raise_for_status()
            return
        except Exception:
            # Broad on purpose (any request failure is retried), but not a
            # bare except, so SystemExit/KeyboardInterrupt are not swallowed.
            print(f'monitor exit dist network err {try_i}', flush=True)
            time.sleep(1)

    # The server could not be reached: give up and hard-kill this process.
    print('monitor failed set exit flag', flush=True)
    os.kill(os.getpid(), signal.SIGKILL)
