import argparse
import asyncio
from multiprocessing import Process, Event
from flex_attention_vllm.launcher import SystemLauncher
from flex_attention_vllm.vllm_simulator.api_server import start_simulation
from types import SimpleNamespace
import os
import sys
import time
import datetime
import signal
import subprocess
from typing import Dict, List
from flex_attention_vllm.logger import init_logger

# Module-wide logger obtained from the project's logging helper.
logger = init_logger(__name__)

# Comma-separated "ip:port" endpoints of the replicas; forwarded unchanged to
# the launcher through build_namespace().
replicas_ip_port = "127.0.0.1:8081,127.0.0.1:8082,127.0.0.1:8083,127.0.0.1:8084,127.0.0.1:8085,127.0.0.1:8086,127.0.0.1:8087,127.0.0.1:8088"

# Stage switches read by main():
#   start_proxy_enable   - run the proxy/client experiment in a child process
#   start_service_enable - (re)launch the vLLM servers before each experiment
#   stop_service_enable  - stop the vLLM servers once and return from main()
start_proxy_enable = True
start_service_enable = False
stop_service_enable = False

# Global scheduler types swept by main().
global_scheduler_type_list = ["double_hash","cache_affinity", "preble","min_ttft","min_pending_input"]

# Balancer types swept by main(). main() only runs "" with non-"double_hash"
# schedulers and only non-empty values with "double_hash".
# Alternative sweep kept for the ablation runs:
#   ["dh_least_loaded","dh_cache_affinity","dh_min_ttft","no_balance"]
balance_type_list = ["","ttft_slo_aggresive"] 
# Request-generation rates swept by main() (passed on as "request_generate_qps").
qps_list = [3,4]

# Request counts swept by main().
request_num_list = [4000]
# Listening port and device index of each replica; the two lists are zipped
# pairwise when the servers are started (device -> ASCEND_RT_VISIBLE_DEVICES).
ports = [8081, 8082, 8083, 8084, 8085, 8086, 8087, 8088]
devices = [0, 1, 2, 3, 4, 5, 6, 7]
# NOTE(review): not referenced anywhere in the visible part of this file.
vllm_log_keyword = "DEBUG"

# Model identifiers: `model_path` is passed to vLLM's --model, `model_name` to
# --served-model-name and is also used as a result-directory component.
model = "Qwen/Qwen2.5-7B-Instruct"
model_path = "Qwen/Qwen2.5-7B-Instruct" 
model_name = "qwen2.5-7B-Instruct"
replica_num = 8
# Per-replica cache size in GiB; `cache_capacity` is the same value in bytes.
replica_dram = 64
cache_capacity = replica_dram << 30
# NOTE(review): start_single_vllm_service() hard-codes --block-size=128 rather
# than reading this value; keep the two in sync.
block_size = 128
kv_cache_size_per_token = 57344 # 57344 for Qwen2.5-7B-Instruct
max_model_len = 20480
# Root under which main() builds the 'result/...' and 'dataset/...' paths.
workspace_path = "./"
# NOTE(review): main() iterates `request_num_list` with a loop variable of the
# same name, so this module-level value is not what main() passes on.
request_num = 500
request_share_ratio = 0.4 
request_first_prefill_len = 64
request_decode_len = 16
request_round = 10
request_active_timeout = 300 # previously 30; unit presumably seconds - TODO confirm
# NOTE(review): main() assigns its own local `dataset_type`
# ("conversatio-toolagent-{online|offline}-{ct_ratio}"), so this module-level
# value is not what main() passes on.
dataset_type = "conversatio-toolagent" 
ct_ratio = 4000
process_dataset_online = False
# NOTE(review): not referenced anywhere in the visible part of this file.
process_dataset_online_list = [False]
# File name looked up under <workspace_path>/dataset/mooncake by main().
request_dataset_file_name = "processed_conversation_trace.jsonl"
request_native_qps = 6.5
# TTFT SLO; unit presumably seconds - TODO confirm against the scheduler.
ttft_slo = 5
# (sic) "rario": the misspelled name is also the namespace key consumers read.
abnormal_requests_rario = 0 # toolagent_trace:0.39,mooncake-conversation:0
abnormal_requests_num = 0
warm_up_requests_num = 500
warm_up_qps = 0.3 # 8 replica 0.3 for conversation,0.5 for toolagent
requests_num_dataset_start = 0 
prefill_tpot = 0.0000635 # 910b4 and Qwen2.5-7B-Instruct
dh_recompute_punish_ratio = 1

dh_balance_decode_busy_enable = False
decode_busy_threshold = ttft_slo
# Thresholds derived from the SLO: ttft_slo (or half of it) divided by
# prefill_tpot. Presumably a token budget that can be prefilled within the
# SLO - TODO confirm against the scheduler implementation.
# NOTE(review): dh_first_balance_budget_thredhold is not referenced in the
# visible part of this file.
dh_first_balance_budget_thredhold = int((ttft_slo/2) * (1/prefill_tpot))
dh_rebalance_thredhold = int(ttft_slo * (1/prefill_tpot))     
dh_first_balance_ttft_thredhold = dh_rebalance_thredhold
replica_slo_budget = dh_rebalance_thredhold

dh_cancel_rebalance_req = False
discard_req_flag = False
discard_req_threshold = 5 # second
dh_window_duration = 30 # minute
dh_replica_pending_req_threshold = 1
dh_rebalance_waiting_latency_thredhold = 3
preble_window_duration = 3 # minute
busy_prefill_interval = 2
dh_extend_replica = False



# port -> PIDs of the python processes found under the server launched on that
# port. Filled by _start_all_vllm_services_helper() and pruned by
# vllm_services_health_check().
server_list: Dict[int, List[int]] = {}

def build_namespace(global_scheduler_type: str, qps: float,
                    model, model_path, model_name,
                    replica_num, replica_dram,
                    request_num, request_share_ratio,
                    request_first_prefill_len, request_decode_len,
                    request_round, request_active_timeout,
                    dataset_type, request_dataset_file,
                    request_dataset_dir, process_dataset_online,
                    balance_type, ttft_slo, replica_slo_budget,
                    dh_recompute_punish_ratio, dh_rebalance_thredhold,
                    dh_rebalance_waiting_latency_thredhold, dh_extend_replica,
                    result_path
                    ) -> SimpleNamespace:
    """Bundle the settings of one experiment into an attribute namespace.

    Every parameter is stored under an attribute of the same name, except
    ``qps`` which is stored as ``request_generate_qps``. The remaining
    attributes are taken from the module-level configuration constants
    (``replicas_ip_port``, ``cache_capacity``, ``block_size``, ...) at call
    time, plus two fixed values: ``update_replica_info=True`` and
    ``window_duration=30``.

    Note that the parameters ``ttft_slo``, ``replica_slo_budget``,
    ``dh_recompute_punish_ratio``, ``dh_rebalance_thredhold``,
    ``dh_rebalance_waiting_latency_thredhold``, ``dh_extend_replica``,
    ``process_dataset_online``, ``dataset_type`` and ``request_num`` shadow
    module-level constants of the same name; the argument value wins.

    Returns:
        A ``types.SimpleNamespace`` (not an ``argparse.Namespace``) whose
        attributes are created in a fixed order.
    """
    return SimpleNamespace(
        # --- replica / model description ---
        replicas_ip_port=replicas_ip_port,
        model=model,
        model_path=model_path,
        model_name=model_name,
        replica_num=replica_num,
        replica_dram=replica_dram,
        cache_capacity=cache_capacity,
        block_size=block_size,
        kv_cache_size_per_token=kv_cache_size_per_token,
        max_model_len=max_model_len,
        # --- request generation ---
        request_generate_qps=qps,
        request_native_qps=request_native_qps,
        request_num=request_num,
        warm_up_requests_num=warm_up_requests_num,
        warm_up_qps=warm_up_qps,
        requests_num_dataset_start=requests_num_dataset_start,
        request_share_ratio=request_share_ratio,
        request_first_prefill_len=request_first_prefill_len,
        request_decode_len=request_decode_len,
        request_round=request_round,
        request_active_timeout=request_active_timeout,
        abnormal_requests_rario=abnormal_requests_rario,
        abnormal_requests_num=abnormal_requests_num,
        # --- dataset ---
        dataset_type=dataset_type,
        process_dataset_online=process_dataset_online,
        ct_ratio=ct_ratio,
        request_dataset_file=request_dataset_file,
        request_dataset_dir=request_dataset_dir,
        # --- scheduling / balancing ---
        global_scheduler_type=global_scheduler_type,
        balance_type=balance_type,
        prefill_tpot=prefill_tpot,
        dh_balance_decode_busy_enable=dh_balance_decode_busy_enable,
        decode_busy_threshold=decode_busy_threshold,
        dh_first_balance_ttft_thredhold=dh_first_balance_ttft_thredhold,
        dh_rebalance_thredhold=dh_rebalance_thredhold,
        dh_rebalance_waiting_latency_thredhold=dh_rebalance_waiting_latency_thredhold,
        dh_recompute_punish_ratio=dh_recompute_punish_ratio,
        dh_cancel_rebalance_req=dh_cancel_rebalance_req,
        discard_req_flag=discard_req_flag,
        discard_req_threshold=discard_req_threshold,
        replica_slo_budget=replica_slo_budget,
        ttft_slo=ttft_slo,
        update_replica_info=True,
        window_duration=30,
        dh_window_duration=dh_window_duration,
        dh_replica_pending_req_threshold=dh_replica_pending_req_threshold,
        preble_window_duration=preble_window_duration,
        # --- output ---
        result_path=result_path,
        dh_extend_replica=dh_extend_replica,
        busy_prefill_interval=busy_prefill_interval,
    )

# 1. Manage vLLM server instances
def execute_shell_cmd(cmd: str, env: Dict[str, str] = None) -> subprocess.Popen:
    """Run ``cmd`` through ``/bin/bash`` in a new session and return the handle.

    Args:
        cmd: Shell command line, interpreted by bash.
        env: Environment for the child. A falsy value (``None`` or an empty
            dict) makes the child inherit ``os.environ``.

    Returns:
        The ``subprocess.Popen`` object of the bash process. Its stdout and
        stderr are sent to ``/dev/null`` unless ``cmd`` redirects them itself.
        The function does not wait for the command to finish.
    """
    return subprocess.Popen(
        cmd,
        shell=True,
        executable="/bin/bash",
        env=env or os.environ,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
        # The child calls setsid() before exec, so it leads its own session and
        # process group and can later be signalled as a group without touching
        # this script. `start_new_session` is the documented replacement for
        # `preexec_fn=os.setsid`, which is not safe in the presence of threads.
        start_new_session=True,
    )


def kill_process_group(pid: int):
    """Terminate the whole process group that ``pid`` belongs to.

    The group is sent SIGTERM, given two seconds to shut down, and then sent
    SIGKILL unconditionally so that stragglers cannot survive.

    The call is best-effort and never raises for the expected failure modes:
    a PID or group that no longer exists, or one this user may not signal, is
    logged and skipped. A group that is the caller's own process group is
    refused, because signalling it would kill this script as well.

    Args:
        pid: Any process ID inside the group to terminate.
    """
    try:
        gid = os.getpgid(pid)
    except ProcessLookupError:
        logger.debug(f"process {pid} not found, nothing to kill")
        return

    if gid == os.getpgrp():
        # PIDs may come from a pattern match (pgrep -f); never let a stray
        # match take down the experiment driver itself.
        logger.warning(f"refusing to kill own process group {gid} (pid {pid})")
        return

    try:
        os.killpg(gid, signal.SIGTERM)
        time.sleep(2)  # grace period before escalating to SIGKILL
        logger.debug(f"kill -15 {pid} success")
    except ProcessLookupError:
        logger.debug(f"kill -15 {pid} failed: process group {gid} already gone")
        return
    except PermissionError:
        logger.warning(f"kill -15 {pid} failed: not permitted to signal process group {gid}")
        return

    try:
        os.killpg(gid, signal.SIGKILL)
        logger.debug(f"kill -9 {pid} success")
    except ProcessLookupError:
        # Normal outcome when SIGTERM was enough and the group was reaped.
        logger.debug(f"kill -9 {pid} skipped: process group {gid} exited after SIGTERM")
    except PermissionError:
        logger.warning(f"kill -9 {pid} failed: not permitted to signal process group {gid}")

def stop_all_vllm_services():
    """Kill every vLLM API server on this machine, tracked or not.

    The PIDs come from two sources: every process whose command line matches
    ``vllm.entrypoints.openai.api_server`` (via ``pgrep -f``) and every PID
    recorded in ``server_list``. Each distinct PID is handed to
    ``kill_process_group`` once, followed by a five second pause.

    ``pgrep`` is invoked without a shell: with ``shell=True`` the wrapping
    ``sh -c "pgrep -f '<pattern>'"`` process carries the pattern in its own
    command line and can show up in the result.
    """
    print("终止vllm服务")

    try:
        output = subprocess.check_output(
            ["pgrep", "-f", "vllm.entrypoints.openai.api_server"],
            text=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # pgrep exits non-zero when nothing matches; OSError covers a missing
        # pgrep binary. Either way there is nothing discovered to kill.
        output = ""

    target_pids = {int(token) for token in output.split()}
    for tracked_pids in server_list.values():
        target_pids.update(tracked_pids)

    # Deduplicated, so a PID known from both sources is not killed (and waited
    # for) twice.
    for pid in sorted(target_pids):
        kill_process_group(pid)
        time.sleep(5)


def start_single_vllm_service(port: int, device: int, sys_args) -> List[int]:
    """Launch one vLLM OpenAI API server and return its python PIDs.

    The server is started in the background through ``execute_shell_cmd`` with
    ``ASCEND_RT_VISIBLE_DEVICES`` set to ``device``. Its stdout is redirected
    to ``<sys_args.result_path>/cache_hit_ratio_<port>.log``. After a fixed
    60 second wait the process tree of the launching shell is scanned with
    ``pstree`` for ``python(<pid>)`` entries.

    NOTE(review): the model path, served model name and max model length are
    read from the module-level constants rather than from ``sys_args``, and
    ``--block-size`` is hard-coded to 128.

    Args:
        port: TCP port the server listens on.
        device: Device index exported as ``ASCEND_RT_VISIBLE_DEVICES``.
        sys_args: Experiment namespace; only ``result_path`` is read here.

    Returns:
        The PIDs found, or an empty list when the scan fails (for example
        because the server has already exited, so grep matches nothing).
    """
    os.makedirs(sys_args.result_path, exist_ok=True)
    log_file = f"{sys_args.result_path}/cache_hit_ratio_{port}.log"
    cmd = (
        f"export ASCEND_RT_VISIBLE_DEVICES={device} && "
        f"python -u -m vllm.entrypoints.openai.api_server "
        f"--model {model_path} "
        f"--max-num-seqs=256 "
        f"--max-model-len={max_model_len} "
        f"--max-num-batched-tokens={max_model_len} "
        f"--dtype=float16 "
        f"--tensor-parallel-size=1 "
        f"--block-size=128 "
        f"--host=0.0.0.0 "
        f"--port={port} "
        f"--gpu-memory-utilization=0.9 "
        f"--trust-remote-code "
        f"--served-model-name {model_name} "
        f"--use-kvcache-store "
        f"> {log_file}"
    )
    print(f"start server:{cmd}")
    process = execute_shell_cmd(cmd)
    # Fixed wait before looking for the server's processes.
    time.sleep(60)

    pids = []
    try:
        # Raw f-string: the PCRE pattern needs literal backslashes, and "\("
        # in a normal string literal is an invalid escape sequence.
        output = subprocess.check_output(
            rf"pstree -p {process.pid} | grep -oP 'python\(\K\d+'",
            shell=True,
            text=True
        )
        pids = [int(pid) for pid in output.strip().split()]
    except (subprocess.CalledProcessError, OSError, ValueError) as e:
        print(f"fail: {str(e)}")

    return pids

def _start_all_vllm_services_helper(sys_args):
    global server_list
    server_list.clear()
    
    for port, device in zip(ports, devices):
        pids = start_single_vllm_service(port, device, sys_args)
        if pids:
            server_list[port] = pids
            print(f" {port} success, PID: {pids}")
        else:
            print(f" {port} failed to start.")

def vllm_services_health_check(result_path, sync_log_enable) -> bool:
    """Check that every tracked vLLM server still has a live process.

    The PIDs recorded in ``server_list`` are compared with the PIDs of all
    processes whose command line matches ``vllm.entrypoints.openai.api_server``
    (via ``pgrep -f``). As a side effect each ``server_list`` entry is pruned
    to the PIDs that are still alive.

    ``pgrep`` is invoked without a shell so the result cannot include the
    wrapping ``sh -c`` process, whose command line also contains the pattern.

    Args:
        result_path: Result directory; created if missing.
        sync_log_enable: Accepted for interface compatibility; currently
            unused.

    Returns:
        ``True`` when every port in ``server_list`` keeps at least one live
        PID (trivially ``True`` when nothing is tracked), ``False`` otherwise.
        When ``pgrep`` fails or finds nothing, all tracked servers count as
        dead.
    """
    os.makedirs(result_path, exist_ok=True)

    try:
        output = subprocess.check_output(
            ["pgrep", "-f", "vllm.entrypoints.openai.api_server"],
            text=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # Non-zero exit means no match; OSError covers a missing pgrep binary.
        output = ""
    active_pids = {int(token) for token in output.split()}

    all_valid = True
    # Iterate over a snapshot because the entries are rewritten in the loop.
    for port, pids in list(server_list.items()):
        live_pids = [pid for pid in pids if pid in active_pids]
        server_list[port] = live_pids
        if not live_pids:
            all_valid = False

    return all_valid

def start_all_vllm_services(sys_args, max_retries=3, restart_delay=5, warmup_delay=120):
    """(Re)start all vLLM servers until they pass the health check.

    Each attempt stops whatever is running, waits ``restart_delay`` seconds,
    launches every server, waits ``warmup_delay`` seconds and then runs
    ``vllm_services_health_check``.

    Args:
        sys_args: Experiment namespace; forwarded to the launcher, and its
            ``result_path`` is passed to the health check.
        max_retries: Maximum number of start attempts (default 3).
        restart_delay: Seconds to wait between stopping and starting
            (default 5).
        warmup_delay: Seconds to wait after starting before the health check
            (default 120).

    Returns:
        ``True`` as soon as one attempt passes the health check, ``False``
        when all attempts failed (or ``max_retries`` is not positive). On
        failure the servers of the last attempt are left as they are;
        cleaning up is the caller's job.
    """
    for attempt in range(1, max_retries + 1):
        stop_all_vllm_services()
        time.sleep(restart_delay)
        _start_all_vllm_services_helper(sys_args)
        time.sleep(warmup_delay)
        if vllm_services_health_check(sys_args.result_path, False):
            print("health_check:", server_list)
            return True
        print(f"health_check failed (attempt {attempt}/{max_retries})")
    return False

# 2. Manage the client and proxy
async def _async_worker(launcher, result_path):
    try:
        await launcher.run()
        while await launcher.is_request_active():
            if  not vllm_services_health_check(result_path, True):
                break

            await asyncio.sleep(60)
        
        logger.debug(f"launcher.is_request_active()==false")
            
    finally:
        logger.debug("client 已停止")

def single_experiment(global_scheduler_type: str, qps: float, sys_args):
    """Run one experiment to completion on a private asyncio event loop.

    A ``SystemLauncher`` is built from ``sys_args``, initialised, and then
    driven by ``_async_worker`` until it finishes. The event loop is always
    closed afterwards, even when a step raises.

    ``global_scheduler_type`` and ``qps`` are accepted for the caller's
    convenience only; the values actually used are the ones in ``sys_args``.
    """
    logger.debug(f"\n {sys_args.global_scheduler_type} {sys_args.balance_type}/{sys_args.request_generate_qps}")

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    run_to_completion = event_loop.run_until_complete

    try:
        system = SystemLauncher(sys_args)
        # Two separate runs: initialisation must finish before the worker
        # coroutine is even created.
        run_to_completion(system.initialize(sys_args))
        run_to_completion(_async_worker(system, sys_args.result_path))
    finally:
        event_loop.close()
        logger.debug(f"=== stop {sys_args.global_scheduler_type} {sys_args.balance_type}/{sys_args.request_generate_qps} ===")

def main():

    for request_num in request_num_list:
        for qps in qps_list:
            for global_scheduler_type in global_scheduler_type_list:
                for balance_type in balance_type_list:
                    global_scheduler_dir = ""
                    if global_scheduler_type != "double_hash":
                        if balance_type == "": 
                            global_scheduler_dir = global_scheduler_type
                        else:
                            continue
                    elif global_scheduler_type == "double_hash":
                        if balance_type == "":
                            continue
                        else:
                            if balance_type in ["dh_least_loaded","dh_cache_affinity", "dh_min_ttft"]:
                                global_scheduler_dir = f'p_dh_{balance_type}-extend-{dh_extend_replica}'
                            elif balance_type in ["ttft_slo", "ttft_slo_aggresive", "ttft_avg","no_balance"]:
                                global_scheduler_dir = f'p_dh_{balance_type}-{dh_first_balance_ttft_thredhold}-{dh_rebalance_thredhold}-{dh_rebalance_waiting_latency_thredhold}-{dh_extend_replica}'
                            elif balance_type in ["nb_cost1", "rb_cost1","rb_cost1_aggresive","rb_cost1_avg"]:
                                global_scheduler_dir = f'p_dh_{balance_type}-{dh_recompute_punish_ratio}-{dh_first_balance_ttft_thredhold}-{dh_rebalance_thredhold}-{dh_rebalance_waiting_latency_thredhold}-{dh_extend_replica}'
                                print(global_scheduler_dir)
                    abnormal_requests_num = 0
                    result_path = ""
                    request_dataset_file = ""
                    if process_dataset_online:
                        dataset_type = f"conversatio-toolagent-online-{ct_ratio}" 
                    else:
                        dataset_type = f"conversatio-toolagent-offline-{ct_ratio}" 
                    if dataset_type == "simulated-mooncake/conversation":
                        result_path = os.path.join(
                            workspace_path,
                            'result',
                            model_name,
                            'simulated-mooncake',
                            'conversation',
                            f'req-{request_num}-cache-{replica_dram}G',
                            f'replica-{replica_num}-qps-{qps}',
                            f'share-{request_share_ratio}-abnormal-{abnormal_requests_rario}-warmup{warm_up_requests_num}',
                            f'p{request_first_prefill_len}-d{request_decode_len}-r{request_round}',
                            global_scheduler_dir
                        )
                        request_dataset_file = os.path.join(
                            workspace_path,
                            'dataset',
                            'simulated-mooncake',
                            'conversation',
                            f'share-{request_share_ratio}',
                            f'req-{request_num}-p{request_first_prefill_len}-d{request_decode_len}-r{request_round}.txt'
                        )
                    else:
                        result_path = os.path.join(
                            workspace_path,
                            'result',
                            model_name,
                            dataset_type,
                            f'req{request_num}-warm{warm_up_requests_num}-cache{replica_dram}G',
                            f'replica{replica_num}-qps{qps}',
                            global_scheduler_dir
                        )

                        request_dataset_file = os.path.join(
                            workspace_path,
                            'dataset',
                            'mooncake',
                            request_dataset_file_name
                        )

                        request_dataset_dir = os.path.join(
                            workspace_path,
                            'dataset',
                            'mooncake'
                        )                        

                    sys_args = build_namespace(global_scheduler_type, qps, 
                            model, model_path, model_name, 
                            replica_num, replica_dram,
                            request_num, request_share_ratio,
                            request_first_prefill_len, request_decode_len,
                            request_round, request_active_timeout,
                            dataset_type, request_dataset_file, 
                            request_dataset_dir, process_dataset_online,
                            balance_type, ttft_slo, replica_slo_budget,
                            dh_recompute_punish_ratio,dh_rebalance_thredhold,
                            dh_rebalance_waiting_latency_thredhold,dh_extend_replica,
                            result_path
                            ) 
                    # 0. stop vllm servers
                    if stop_service_enable:
                        logger.debug(f"stop_vllm_services")
                        stop_all_vllm_services()
                        return

                    # 1. launch vllm server
                    if start_service_enable:
                        if not start_all_vllm_services(sys_args):
                            stop_all_vllm_services()
                            time.sleep(5)         
                            continue

                    # 2. launch proxy, client
                    if start_proxy_enable:
                        p = Process(target=single_experiment, args=(global_scheduler_type, qps, sys_args))
                        p.start()
                        p.join() 
                        logger.debug(f" {global_scheduler_dir}/{qps}")

                        # 3. stop vllm servers
                        stop_all_vllm_services()
                        time.sleep(5)
                        logger.debug(f"stop_vllm_services, {global_scheduler_dir}/{qps}")
                    
if __name__ == "__main__":
    main()