import torch
import zmq
import pickle
import torch.distributed as dist
import traceback
import time
import random
import json
from collections import deque
from logger import logger
from typing import List
from functools import reduce

from gllm.input_data import InputData
from gllm.sequence import Sequence
from gllm.model_runner import ModelRunner
from gllm.dist_utils import init_dist, send_pp_data, send_tensor,recv_pp_data,recv_tensor,get_pp_used_idx_layers,get_pp_load_layers
from gllm.scheduler import SchedulerOutput
from gllm.utils import make_socket
from gllm.memory_manager import PrefixMemoryManager
import math
# Module-level layer-adjust default. Not referenced in the code visible here;
# presumably legacy/configuration — confirm before relying on it.
ADJUST_LAYERS: int = 0

# Coefficients from profile "430 profile splitwise_conv rate_18 32b".
def predict_worker_time(prefill, decode, memory_util):
    """Predict one pipeline stage's forward time for a batch.

    Linear model: a fixed intercept plus a cost per prefill token and a cost
    per decode sequence. ``memory_util`` is accepted for interface
    compatibility but does not affect the result. Units are whatever the
    profile was measured in (presumably milliseconds — TODO confirm).
    """
    intercept = 11.238080
    prefill_term = 0.075143 * prefill
    decode_term = 0.140392 * decode
    # Summed left to right, exactly like the original single expression.
    return intercept + prefill_term + decode_term
def predict_sample_time(prefill, decode, memory_util):
    """Predict the sampling time for a batch (same profile as the worker model).

    Depends only on the number of decode sequences; ``prefill`` and
    ``memory_util`` are accepted for interface compatibility but unused.
    """
    intercept, per_decode_seq = 1.808118, 0.044461
    return intercept + per_decode_seq * decode

# 14b model coefficients (inactive; the 32b coefficients above are the ones in use)
# def predict_worker_time(prefill, decode,memory_util):     return 5.994919      + 0.029744 * prefill      + 0.110212 * decode # 14b
# def predict_sample_time(prefill, decode,memory_util):     return 1.795752      + 0.044437 * decode # 14b

# Used with PipeAsyncLLM
class Worker:
    """One pipeline-parallel stage of the model, driven by ``run_worker``.

    Rank 0 pulls new sequences from the main process, schedules batches, runs
    the first stage, forwards its output to the next rank, and pushes results
    ``(SchedulerOutput, tokens)`` back to the main process. Ranks > 0 receive
    the same batches from rank 0 plus the previous stage's intermediate data
    (via ``recv_pp_data``), run their stage and forward the result; the last
    rank pushes the next tokens back to rank 0 over ZMQ.

    With ``enable_adjust_layers``, rank 0 also uses the module's latency
    models to pick a layer-split adjustment and the ranks exchange KV-cache
    layers so the split between neighbouring stages can change at runtime.
    """

    # Value of recv_start_layer_idx meaning "no pending layer-wise KV receive".
    # Presumably larger than any real layer index — TODO confirm against the model.
    _NO_PENDING_RECV_LAYER = 1000000

    def __init__(self, model_runner: ModelRunner, pp_rank, pp_size,
                 master_addr, master_port, schedule_ipc_path, layer_adjust_ipc_path, layer_reset_ipc_path,
                 output_ipc_path, token_ipc_path, mp_alive, enable_adjust_layers=False):
        """Store configuration; sockets, devices and the model are set up in ``init``.

        Args:
            model_runner: runner that owns the model and its memory manager.
            pp_rank: pipeline rank of this worker (also used as the CUDA device index).
            pp_size: number of pipeline stages.
            master_addr, master_port: rendezvous address passed to ``init_dist``.
            schedule_ipc_path: ZMQ path for main process -> rank 0 schedules;
                ``{path}_{i}`` carries rank 0 -> rank i batches.
            layer_adjust_ipc_path: ZMQ path prefix for layer-adjust signals.
            layer_reset_ipc_path: ZMQ path prefix for layer-reset signals.
            output_ipc_path: ZMQ path for rank 0 -> main process results.
            token_ipc_path: ZMQ path for last rank -> rank 0 next tokens.
            mp_alive: shared indexable status; ``init`` writes 1 at ``pp_rank``.
            enable_adjust_layers: enable runtime layer rebalancing.
        """
        self.model_runner = model_runner
        self.pp_rank = pp_rank  # pp rank
        self.pp_size = pp_size  # pp size
        self.master_addr = master_addr
        self.master_port = master_port
        self.schedule_ipc_path = schedule_ipc_path
        self.output_ipc_path = output_ipc_path
        self.token_ipc_path = token_ipc_path
        self.mp_alive = mp_alive
        self.enable_adjust_layers = enable_adjust_layers
        # Id assigned by rank 0 to each scheduled batch (InputData), incremented per batch.
        self.input_data_cnt = 0
        if self.enable_adjust_layers:
            self.layer_adjust_ipc_path = layer_adjust_ipc_path
            self.layer_reset_ipc_path = layer_reset_ipc_path
            # True when a layer adjustment must be applied after the current step.
            self.is_adjusted = False
            self.target_adjust_layers = -1
            # Pending adjustment to attach to the next batch: -1 down, 1 up, 0 none.
            self.adjust_flag = 0
            self.recv_start_layer_idx = self._NO_PENDING_RECV_LAYER
            self.recv_k_future_list = []
            self.recv_v_future_list = []
            # Id of the last batch that runs on the old split before an adjustment.
            self.last_input_data_id = -1
            if self.pp_rank == 0:
                # Smoothing window for predictions (its use in schedule() is commented out).
                self.ratio_queue = deque([0] * 30, maxlen=30)
                # Latest predicted target and how many consecutive schedules predicted it.
                self.ratio = 0
                self.count = 0

        # debug: time of the last periodic queue-length log
        self.last_time = 0

    def get_pp_next_rank(self):
        """Return the device rank of the next pipeline stage (wraps around)."""
        return (self.pp_rank + 1) % self.pp_size

    def get_pp_last_rank(self):
        """Return the device rank of the previous pipeline stage (wraps around)."""
        return (self.pp_rank - 1 + self.pp_size) % self.pp_size

    def init(self):
        """Initialise distributed state, CUDA device, ZMQ sockets and the model.

        Called by ``run_worker`` before the event loop. Writes
        ``mp_alive[pp_rank] = 1`` once setup is complete.
        """
        init_dist(self.pp_size, self.pp_rank, self.pp_size, self.pp_rank, self.master_addr, self.master_port)
        torch.cuda.set_device(f'cuda:{self.pp_rank}')
        zmq_ctx = zmq.Context()

        if self.enable_adjust_layers:
            self.layer_reset_socket = make_socket(zmq_ctx, f'{self.layer_reset_ipc_path}_{self.pp_rank}', zmq.PULL)

        if self.pp_rank == 0:
            # main process => rank 0
            self.schedule_socket = make_socket(zmq_ctx, self.schedule_ipc_path, zmq.PULL)
            # rank 0 => main process
            self.output_socket = make_socket(zmq_ctx, self.output_ipc_path, zmq.PUSH)
            # seqs to schedule
            self.seqs_to_prefill: deque[Sequence] = deque()
            self.seqs_to_decode: deque[Sequence] = deque()
            # batches in flight through the pipeline, oldest first
            self.batch_running = deque()
            self.log_time = 0
            # preempted-seq counters (total / total at the last warning)
            self.num_preempt_seqs = 0
            self.log_num_preempt_seqs = 0
            # prompt tokens still waiting to be prefilled
            self.num_wait_tokens = 0
            # rank 0 => other ranks: batched seqs and layer-adjust signals.
            # Rank 0 only *sends* adjust signals, so it opens no PULL socket on
            # layer_adjust_ipc_path (such a socket would be overwritten here unused).
            self.gpu_schedule_socket = []
            self.layer_adjust_socket = []
            for i in range(1, self.pp_size):
                self.gpu_schedule_socket.append(make_socket(zmq_ctx, f'{self.schedule_ipc_path}_{i}', zmq.PUSH))
                if self.enable_adjust_layers:
                    self.layer_adjust_socket.append(make_socket(zmq_ctx, f'{self.layer_adjust_ipc_path}_{i}', zmq.PUSH))
            if self.pp_size != 1:
                # last rank => rank 0 : next tokens
                self.token_socket = make_socket(zmq_ctx, self.token_ipc_path, zmq.PULL)
        else:
            # rank 0 => this rank: scheduled batches
            self.gpu_schedule_socket = make_socket(zmq_ctx, f'{self.schedule_ipc_path}_{self.pp_rank}', zmq.PULL)
            if self.enable_adjust_layers:
                # rank 0 => this rank: layer-adjust signals
                self.layer_adjust_socket = make_socket(zmq_ctx, f'{self.layer_adjust_ipc_path}_{self.pp_rank}', zmq.PULL)
            # Input data for each rank except 0
            self.schedule_queue = deque()
            # Input data and intermediate data for rank except 0
            self.run_queue = deque()
        if self.pp_rank == self.pp_size - 1 and self.pp_size != 1:
            # The last pipeline stage also sends next tokens to rank 0.
            self.token_socket = make_socket(zmq_ctx, self.token_ipc_path, zmq.PUSH)
        self.model_runner.init()
        self.dtype = self.model_runner.memory_manager.dtype
        self.hidden_size = self.model_runner.model_loader.hidden_size
        self.ret_residual = self.model_runner.model.ret_residual
        self.max_decode_seqs = self.model_runner.max_decode_seqs
        self.max_batch_tokens = self.model_runner.max_batch_tokens
        self.page_size = self.model_runner.page_size
        self.ratio_threshold_free_pages = self.model_runner.ratio_threshold_free_pages
        self.num_threshold_free_pages = int(
            self.model_runner.ratio_threshold_free_pages * self.model_runner.memory_manager.get_num_free_pages())

        # The worker on this GPU started successfully.
        self.mp_alive[self.pp_rank] = 1

    def get_num_free_pages(self):
        """Return the memory manager's current number of free KV-cache pages."""
        return self.model_runner.memory_manager.get_num_free_pages()

    def _used_layer_ranges(self, target_adjust_layers: int):
        """Return this rank's current and target used-layer ranges.

        Returns:
            ``(cur_adjust_layers, cur_start, cur_end, target_start, target_end,
            load_start)``. The start/end pairs come from ``get_pp_used_idx_layers``.
            ``load_start`` is the first index from ``get_pp_load_layers``; the
            KV-cache segment is indexed by ``layer - load_start``.
        """
        num_layers = self.model_runner.model.model.num_hidden_layers
        cur_adjust_layers = self.model_runner.model.get_adjust_layers()
        cur_start, cur_end = get_pp_used_idx_layers(num_layers, cur_adjust_layers)
        target_start, target_end = get_pp_used_idx_layers(num_layers, target_adjust_layers)
        load_start, _ = get_pp_load_layers(num_layers)
        return cur_adjust_layers, cur_start, cur_end, target_start, target_end, load_start

    def _shift_down_blocking(self, cur_start, cur_end, target_start, target_end, load_start):
        """Blocking KV handoff for a downward adjustment.

        Sends layers ``[target_end, cur_end)`` (k then v, per layer) to the next
        rank. Then receives layers ``[target_start, cur_start)`` from the previous
        rank and waits for every receive to complete.
        """
        segment = self.model_runner.memory_manager.segment
        send_offset = target_end - load_start
        for i in range(cur_end - target_end):
            send_tensor(self.get_pp_next_rank(), segment.k_cache[send_offset + i])
            send_tensor(self.get_pp_next_rank(), segment.v_cache[send_offset + i])

        recv_offset = target_start - load_start
        pending = []
        for i in range(cur_start - target_start):
            k_future = recv_tensor(self.get_pp_last_rank(), segment.k_cache[recv_offset + i])
            v_future = recv_tensor(self.get_pp_last_rank(), segment.v_cache[recv_offset + i])
            pending.append((k_future, v_future))
        for k_future, v_future in pending:
            k_future.wait()
            v_future.wait()

    def _post_tail_recvs(self, cur_end, target_end, load_start):
        """Post receives from the next rank for layers ``[cur_end, target_end)`` without waiting.

        Returns:
            ``(k_futures, v_futures)`` in layer order. The caller attaches them to
            a batch; per the design, they are received during that batch's step.
        """
        segment = self.model_runner.memory_manager.segment
        recv_offset = cur_end - load_start
        k_futures, v_futures = [], []
        for i in range(target_end - cur_end):
            k_futures.append(recv_tensor(self.get_pp_next_rank(), segment.k_cache[recv_offset + i]))
            v_futures.append(recv_tensor(self.get_pp_next_rank(), segment.v_cache[recv_offset + i]))
        return k_futures, v_futures

    def _adjust_layer(self, target_adjust_layers: int, input_data: InputData):
        """Move this rank to ``target_adjust_layers``, tagging ``input_data`` with the change.

        Down (current adjust count > target): a blocking KV handoff
        (``_shift_down_blocking``), with ``input_data.adjust_flag = -1``.

        Up (otherwise, including no change): the head layers were already sent
        during the previous batch's forward. Here, receives for the new tail
        layers are posted and queued on ``input_data``, with
        ``input_data.adjust_flag = 1``.
        """
        cur_adjust_layers, cur_start, cur_end, target_start, target_end, load_start = \
            self._used_layer_ranges(target_adjust_layers)

        if cur_adjust_layers > target_adjust_layers:  # adjust down: recv head layers, send tail layers
            input_data.adjust_flag = -1
            self._shift_down_blocking(cur_start, cur_end, target_start, target_end, load_start)
        else:  # adjust up: send head layers, recv tail layers
            input_data.adjust_flag = 1
            # Tail layers from here on are received during step, once the next GPU has produced them.
            input_data.recv_start_layer_idx = cur_end
            k_futures, v_futures = self._post_tail_recvs(cur_end, target_end, load_start)
            input_data.recv_k_future_queue.extend(k_futures)
            input_data.recv_v_future_queue.extend(v_futures)

        self.model_runner.model.adjust_layer(target_adjust_layers)

    def _adjust_layer_without_inputdata(self, target_adjust_layers: int):
        """Like ``_adjust_layer``, but keeps the pending state on ``self``.

        ``adjust_flag``, ``recv_start_layer_idx`` and the posted receive futures
        are attached to the next batch that runs (in ``run`` / ``schedule_run``)
        and then cleared with ``reset``.
        """
        cur_adjust_layers, cur_start, cur_end, target_start, target_end, load_start = \
            self._used_layer_ranges(target_adjust_layers)

        if cur_adjust_layers > target_adjust_layers:  # adjust down: recv head layers, send tail layers
            self.adjust_flag = -1
            self._shift_down_blocking(cur_start, cur_end, target_start, target_end, load_start)
        else:  # adjust up: send head layers, recv tail layers
            self.adjust_flag = 1
            # The send already happened during the previous batch's forward;
            # the receives below are completed during the next batch's step.
            self.recv_start_layer_idx = cur_end
            k_futures, v_futures = self._post_tail_recvs(cur_end, target_end, load_start)
            self.recv_k_future_list.extend(k_futures)
            self.recv_v_future_list.extend(v_futures)

        self.model_runner.model.adjust_layer(target_adjust_layers)

    # rank except 0
    def run(self):
        """One non-blocking iteration of the event loop for pipeline ranks > 0.

        1. Apply a pending layer-reset signal; record any layer-adjust signal.
        2. If the oldest queued batch's intermediate data has arrived, run this
           stage on it. Then send the output to the next rank, or, on the last
           rank, push the next tokens to rank 0. If this batch was the last one
           on the old split, apply the layer adjustment afterwards.
        3. Pull a newly scheduled batch from rank 0, if any, and post the
           receive for the oldest scheduled batch's intermediate data.
        """
        if time.time() - self.last_time > 5:
            logger.info(f"Worker {self.pp_rank} len(self.run_queue): {len(self.run_queue)}")
            self.last_time = time.time()

        # Layer-reset signal: return to the default split (0 adjusted layers).
        if self.enable_adjust_layers and self.layer_reset_socket.poll(timeout=0) != 0:
            self.layer_reset_socket.recv(copy=False)  # consume the message; its payload is unused
            self.model_runner.model.adjust_layer(0)
            self.count = 0
            self.ratio = 0

        # Layer-adjust signal from rank 0: (target layers, id of the last batch on the old split).
        if self.enable_adjust_layers and self.layer_adjust_socket.poll(timeout=0) != 0:
            recv_bytes = self.layer_adjust_socket.recv(copy=False)
            target_adjust_layers, input_data_id = pickle.loads(recv_bytes)
            self.last_input_data_id = input_data_id
            self.target_adjust_layers = target_adjust_layers

        # model forward
        if len(self.run_queue) != 0:
            hidden_states = None
            residual = None
            input_data, intermediate_data = self.run_queue.popleft()

            # intermediate_data comes from recv_pp_data and has one of two forms:
            # (handle, handle, hidden_states, residual) or (handle, hidden_states).
            # If its receives have not completed, requeue the batch at the front and retry later.
            if len(intermediate_data) == 4:
                if not (intermediate_data[0].is_completed() and intermediate_data[1].is_completed()):
                    self.run_queue.appendleft((input_data, intermediate_data))
                    return
                hidden_states, residual = intermediate_data[2], intermediate_data[3]
            elif len(intermediate_data) == 2:
                if not intermediate_data[0].is_completed():
                    self.run_queue.appendleft((input_data, intermediate_data))
                    return
                hidden_states = intermediate_data[1]
            else:
                # Explicit raise so the check survives `python -O`.
                raise AssertionError(f'unexpected intermediate data of length {len(intermediate_data)}')

            # Attach an adjustment left pending by the previous step to this batch.
            if self.enable_adjust_layers and (self.adjust_flag != 0):
                input_data.adjust_flag = self.adjust_flag
                input_data.recv_start_layer_idx = self.recv_start_layer_idx
                input_data.recv_k_future_queue.extend(self.recv_k_future_list)
                input_data.recv_v_future_queue.extend(self.recv_v_future_list)
                self.reset()

            if self.enable_adjust_layers and input_data.input_data_id == self.last_input_data_id:
                # This is the last batch on the old split. The adjustment is recorded
                # before step so that, when adjusting up, step can send the head
                # layers' KV cache as soon as those layers are computed.
                cur_adjust_layers, cur_start, _, target_start, _, load_start = \
                    self._used_layer_ranges(self.target_adjust_layers)

                if cur_adjust_layers > self.target_adjust_layers:  # adjust down: recv head layers, send tail layers
                    input_data.adjust_flag = -1
                else:  # adjust up: send head layers [cur_start, target_start) to the previous GPU during step
                    input_data.adjust_flag = 1
                    input_data.send_end_layer_idx = target_start
                    send_layers_num = target_start - cur_start
                    send_offset = cur_start - load_start
                    segment = self.model_runner.memory_manager.segment
                    send_k_cache = segment.k_cache[send_offset:send_offset + send_layers_num]
                    send_v_cache = segment.v_cache[send_offset:send_offset + send_layers_num]
                    for i in range(send_layers_num):
                        input_data.send_k_cache_queue.append(send_k_cache[i])
                        input_data.send_v_cache_queue.append(send_v_cache[i])

                self.is_adjusted = True

            # Time the step with CUDA events. synchronize() blocks the host until
            # the GPU has finished, so this serialises CPU and GPU on every step.
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            output = self.model_runner.step_once(input_data, hidden_states, residual)
            end_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(end_event)
            logger.info(f'Worker rank {self.pp_rank} step time: {elapsed_time_ms}')

            if self.pp_rank == self.pp_size - 1:
                assert isinstance(output, list)
                token_bytes = pickle.dumps(output)
                self.token_socket.send(token_bytes, copy=False)
            else:
                send_pp_data(output, self.get_pp_next_rank())

            if self.enable_adjust_layers and self.is_adjusted:
                self._adjust_layer_without_inputdata(self.target_adjust_layers)
                self.is_adjusted = False

        # recv schedule seqs (poll returns non-zero when a message is ready)
        if self.gpu_schedule_socket.poll(timeout=0) != 0:
            recv_bytes = self.gpu_schedule_socket.recv(copy=False)
            seqs, input_data_id = pickle.loads(recv_bytes)
            self.schedule_queue.append(InputData(seqs, self.model_runner.memory_manager, input_data_id))
        if len(self.schedule_queue) != 0:
            # recv intermediate data
            input_data = self.schedule_queue.popleft()
            intermediate_data = recv_pp_data(
                self.get_pp_last_rank(), self.dtype,
                [input_data.tokens.shape[0], self.hidden_size], self.ret_residual)
            self.run_queue.append((input_data, intermediate_data))

    # rank 0
    def get_num_decode_seqs(self):
        """Return queued decode seqs plus all seqs in in-flight batches (prefill included)."""
        return len(self.seqs_to_decode) + sum(len(batch) for batch in self.batch_running)

    # rank 0
    def update_num_wait_tokens(self):
        """Recompute ``num_wait_tokens``: uncomputed tokens over all queued prefill seqs."""
        self.num_wait_tokens = sum(
            len(seq.token_ids) - seq.computed_token_num for seq in self.seqs_to_prefill)

    # rank 0: check if preempt seqs
    def check_preempt(self, num_decode_tokens):
        """Preempt decode seqs until at least ``num_decode_tokens`` pages are free.

        Stops early when no decode seqs remain. Preempted seqs are freed and
        pushed to the front of the prefill queue; ``extendleft`` reverses their
        order. A warning is logged each time at least 10 more preemptions have
        accumulated.
        """
        preempt_seqs = []
        while self.model_runner.memory_manager.get_num_free_pages() < num_decode_tokens and len(self.seqs_to_decode) != 0:
            seq_to_preempt = self.seqs_to_decode.popleft()
            self.model_runner.memory_manager.free(seq_to_preempt)
            seq_to_preempt.preempt()
            preempt_seqs.append(seq_to_preempt)

        self.seqs_to_prefill.extendleft(preempt_seqs)

        self.num_preempt_seqs += len(preempt_seqs)
        if self.num_preempt_seqs - self.log_num_preempt_seqs >= 10:
            self.log_num_preempt_seqs = self.num_preempt_seqs
            logger.warning(f'#Preempted seqs: {self.num_preempt_seqs}')
            logger.warning('Try increase --ratio-free-pages or the performance is poor!')

    def schedule_naive(self):
        """Greedy token-budget scheduler: all decodes first, then prefill chunks.

        Fills ``max_batch_tokens`` with one token per decode seq (taken from the
        right of the decode queue), then with prefill tokens. The last prefill
        seq is split if it does not fit.

        Returns:
            Prefill seqs followed by decode seqs.
        """
        schedule_prefill_seqs = []
        schedule_decode_seqs = []

        num_tokens_budget = self.max_batch_tokens

        self.check_preempt(min(num_tokens_budget, len(self.seqs_to_decode)))
        # decode: one token per sequence
        for _ in range(num_tokens_budget):
            if len(self.seqs_to_decode) == 0:
                break
            seq = self.seqs_to_decode.pop()
            seq.to_compute_token_num = 1
            schedule_decode_seqs.append(seq)

        self.model_runner.memory_manager.pre_allocate_page(schedule_decode_seqs)

        num_tokens_budget -= len(schedule_decode_seqs)

        # prefill (chunked: the whole remainder or whatever budget is left)
        prefill_batched_token_nums = 0
        while len(self.seqs_to_prefill) != 0 and num_tokens_budget != 0:
            seq = self.seqs_to_prefill.popleft()
            remaining = len(seq.token_ids) - seq.computed_token_num
            seq.to_compute_token_num = min(remaining, num_tokens_budget)
            prefill_batched_token_nums += seq.to_compute_token_num
            num_tokens_budget -= seq.to_compute_token_num
            # Allocate pages only after to_compute_token_num is set, as schedule() does.
            self.model_runner.memory_manager.pre_allocate_page([seq])
            schedule_prefill_seqs.append(seq)

        if time.time() - self.log_time > 1:
            self.log_time = time.time()
            log_info = '#wait: %4d #run: %4d #prefill: %4d #decode: %4d memory_util: %5s %%' % (
                            len(self.seqs_to_prefill),
                            self.get_num_decode_seqs(),
                            prefill_batched_token_nums,
                            len(schedule_decode_seqs),
                            '%.2f' % self.model_runner.memory_manager.get_memory_util())
            if isinstance(self.model_runner.memory_manager, PrefixMemoryManager):
                log_info += ' cache_hit_rate: %5s %%' % ('%.2f' % self.model_runner.memory_manager.get_cache_hit_rate())
            logger.info(log_info)
        return schedule_prefill_seqs + schedule_decode_seqs

    # rank 0: PP schedule
    def schedule(self):
        """Build the next pipeline batch (rank 0 only).

        Prefill budget: the tokens that fit in free pages above the reserved
        threshold. With pp_size > 1 it is further scaled down linearly as free
        memory approaches that threshold, then capped at
        ``max(num_wait_tokens // 8, 32)``.

        Decode budget: about ``1 / pp_size`` of all decode seqs, with a random
        rounding offset. The intent is presumably to spread decodes across the
        in-flight batches.

        With ``enable_adjust_layers``, this also updates the ``ratio``/``count``
        layer-adjust prediction from the latency models.

        Returns:
            Scheduled prefill seqs followed by decode seqs.
        """
        schedule_prefill_seqs = []
        schedule_decode_seqs = []

        # prefill
        prefill_token_budget = self.page_size * max(self.get_num_free_pages() - self.num_threshold_free_pages, 0)
        if self.pp_size > 1 and prefill_token_budget != 0:
            self.update_num_wait_tokens()
            free_ratio = self.model_runner.memory_manager.get_memory_free()
            # a = ratio_threshold_free_pages
            # free_ratio in [1,a] | prefill_ratio in [1,0]
            prefill_ratio = (free_ratio - self.ratio_threshold_free_pages) / (1 - self.ratio_threshold_free_pages)
            prefill_ratio = max(prefill_ratio, 0)
            prefill_token_budget = min(
                round(prefill_ratio * self.max_batch_tokens),
                prefill_token_budget)
            prefill_token_budget = min(max(self.num_wait_tokens // 8, 32), prefill_token_budget)
        else:
            prefill_token_budget = min(self.max_batch_tokens, prefill_token_budget)
        prefill_batched_token_nums = 0
        while len(self.seqs_to_prefill) != 0 and prefill_token_budget != 0:
            seq = self.seqs_to_prefill.popleft()
            # For fresh seqs with prefix caching, this presumably marks cached prefix
            # tokens as computed, so it runs before the remaining count is taken.
            if isinstance(self.model_runner.memory_manager, PrefixMemoryManager) and seq.computed_token_num == 0:
                self.model_runner.memory_manager.pre_allocate_computed_page([seq])
            remaining = len(seq.token_ids) - seq.computed_token_num
            # Chunked prefill: the whole remainder or whatever budget is left.
            seq.to_compute_token_num = min(remaining, prefill_token_budget)
            prefill_batched_token_nums += seq.to_compute_token_num
            prefill_token_budget -= seq.to_compute_token_num
            schedule_prefill_seqs.append(seq)

        # Allocate pages for each request's uncomputed tokens into seq.page_table.
        self.model_runner.memory_manager.pre_allocate_page(schedule_prefill_seqs)

        # decode
        num_total_decode_seqs = self.get_num_decode_seqs()
        if num_total_decode_seqs < self.pp_size:
            decode_token_budget = num_total_decode_seqs
        else:
            # The random offset handles cases like #seqs=5, pp_size=4, where plain
            # floor division would leave the remainder unscheduled.
            decode_token_budget = (num_total_decode_seqs + random.randint(0, self.pp_size - 1)) // self.pp_size

        self.check_preempt(decode_token_budget)

        for _ in range(decode_token_budget):
            if len(self.seqs_to_decode) == 0:
                break
            seq = self.seqs_to_decode.popleft()
            seq.to_compute_token_num = 1
            schedule_decode_seqs.append(seq)

        self.model_runner.memory_manager.pre_allocate_page(schedule_decode_seqs)

        if self.pp_rank == 0:
            self.prefill_tokens = prefill_batched_token_nums
            self.decode_tokens = len(schedule_decode_seqs)
            if self.enable_adjust_layers and (len(self.seqs_to_decode) + len(self.seqs_to_prefill) != 0):
                if prefill_batched_token_nums == 0 and len(schedule_decode_seqs) < 5:
                    # Too little work to predict from; restart the streak.
                    self.count = 0
                else:
                    memory_util = self.model_runner.memory_manager.get_memory_util()
                    forward_time = predict_worker_time(prefill_batched_token_nums, len(schedule_decode_seqs), memory_util)
                    sample_time = predict_sample_time(prefill_batched_token_nums, len(schedule_decode_seqs), memory_util)
                    each_layer_forward_time = forward_time / self.model_runner.model.get_num_avg_layers()
                    # Target = how many layers' worth of forward time the sampling step costs, capped at pp_size - 1.
                    ratio = min(round(sample_time / each_layer_forward_time), self.pp_size - 1)
                    # Alternative: ratio = min(math.ceil(sample_time / each_layer_forward_time - 1), self.pp_size - 1)
                    # Count consecutive schedules predicting the same target.
                    if ratio == self.ratio:
                        self.count += 1
                    else:
                        self.count = 1
                    self.ratio = ratio
                    logger.info(f"predict adjust state:{ratio}")

        # NOTE: a threshold of 0 means this status line is logged on every call.
        if time.time() - self.log_time > 0:
            self.log_time = time.time()
            log_info = '#wait: %4d/%8d #run: %4d #prefill: %4d #decode: %4d memory_util: %5s %%' % (
                            len(self.seqs_to_prefill),
                            self.num_wait_tokens,
                            num_total_decode_seqs,
                            prefill_batched_token_nums,
                            len(schedule_decode_seqs),
                            '%.2f' % self.model_runner.memory_manager.get_memory_util())
            if isinstance(self.model_runner.memory_manager, PrefixMemoryManager):
                log_info += ' cache_hit_rate: %5s %%' % ('%.2f' % self.model_runner.memory_manager.get_cache_hit_rate())
            logger.info(log_info)

        return schedule_prefill_seqs + schedule_decode_seqs

    def reset(self):
        """Clear the pending layer-adjust state once it has been attached to a batch."""
        self.adjust_flag = 0
        self.recv_start_layer_idx = self._NO_PENDING_RECV_LAYER
        self.recv_k_future_list = []
        self.recv_v_future_list = []

    # rank 0
    def schedule_run(self):
        """One non-blocking iteration of rank 0's scheduling loop.

        Steps:
        1. Apply a pending layer-reset signal.
        2. Pull newly arrived seqs from the main process.
        3. If there is work and fewer than ``pp_size`` batches are in flight:
           schedule a batch, broadcast it to the other ranks, and run the first
           pipeline stage on it.

        With pp_size > 1, the stage output goes to the next rank, possibly
        preceded by a layer-adjust decision. With pp_size == 1, the output is
        already the next-token list.

        Returns:
            The ``step_once`` output, or None when nothing was run.
        """
        if time.time() - self.last_time > 5:
            logger.info(f"Worker {self.pp_rank} len(self.batch_running): {len(self.batch_running)}")
            logger.info(f"Worker {self.pp_rank} seqs_to_decode: {len(self.seqs_to_decode)}, seqs_to_prefill: {len(self.seqs_to_prefill)}")
            self.last_time = time.time()

        # Layer-reset signal: return to the default split (0 adjusted layers).
        if self.enable_adjust_layers and self.layer_reset_socket.poll(timeout=0) != 0:
            self.layer_reset_socket.recv(copy=False)  # consume the message; its payload is unused
            self.model_runner.model.adjust_layer(0)
            self.count = 0
            self.ratio = 0

        output = None
        # Check for a new schedule message sent by PipeAsyncLLM (local IPC).
        if self.schedule_socket.poll(timeout=0) != 0:
            recv_bytes = self.schedule_socket.recv(copy=False)
            schedulerOutput: SchedulerOutput = pickle.loads(recv_bytes)
            # New requests start in the prefill queue.
            self.seqs_to_prefill.extend(schedulerOutput.schedule_lists)

        if len(self.seqs_to_decode) + len(self.seqs_to_prefill) != 0 and len(self.batch_running) < self.pp_size:
            schedule_seqs = self.schedule()
            if len(schedule_seqs) != 0:
                input_data = InputData(schedule_seqs, self.model_runner.memory_manager, self.input_data_cnt)
                self.input_data_cnt += 1

                if self.pp_size > 1:
                    # gpu_schedule_socket holds one PUSH socket per rank 1..pp_size-1.
                    seqs_bytes = pickle.dumps((schedule_seqs, input_data.input_data_id))
                    for schedule_sock in self.gpu_schedule_socket:
                        schedule_sock.send(seqs_bytes, copy=False)

                # Attach an adjustment left pending by the previous step to this batch.
                if self.enable_adjust_layers and (self.adjust_flag != 0):
                    input_data.adjust_flag = self.adjust_flag
                    input_data.recv_start_layer_idx = self.recv_start_layer_idx
                    input_data.recv_k_future_queue.extend(self.recv_k_future_list)
                    input_data.recv_v_future_queue.extend(self.recv_v_future_list)
                    logger.info(f"Worker {self.pp_rank} adjust to {self.target_adjust_layers} in input_data_id: {input_data.input_data_id} at {time.time()}")
                    self.reset()

                self.batch_running.append(schedule_seqs)
                # Time the step with CUDA events. synchronize() blocks the host until
                # the GPU has finished, so this serialises CPU and GPU on every step.
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                output = self.model_runner.step_once(input_data)
                end_event.record()
                torch.cuda.synchronize()
                elapsed_time_ms = start_event.elapsed_time(end_event)
                logger.info(f'Worker rank {self.pp_rank} step time: {elapsed_time_ms}')

                if not isinstance(output, list):
                    if self.enable_adjust_layers:
                        # Decide whether the layer split should change.
                        cur_adjust_layers = self.model_runner.model.get_adjust_layers()
                        target_adjust_layers = self.ratio
                        # Fires once per streak: when schedule() has predicted the same
                        # target on 25 consecutive qualifying schedules.
                        # Alternative: flag = (abs(target_adjust_layers-cur_adjust_layers)>1 and self.count > 30) or (abs(target_adjust_layers-cur_adjust_layers)==1 and self.count > 150)
                        flag = (self.count == 25)
                        if (target_adjust_layers != cur_adjust_layers) and flag:
                            # Tell the other ranks to adjust after this batch, whose id is
                            # the last one that runs on the old split.
                            self.last_input_data_id = input_data.input_data_id
                            state_bytes = pickle.dumps((target_adjust_layers, input_data.input_data_id))
                            for adjust_sock in self.layer_adjust_socket:
                                adjust_sock.send(state_bytes, copy=False)
                            self.is_adjusted = True
                            self.target_adjust_layers = target_adjust_layers

                    send_pp_data(output, self.get_pp_next_rank())

                    if self.enable_adjust_layers and self.is_adjusted:
                        # Apply the adjustment now that this batch's output has been sent.
                        self._adjust_layer_without_inputdata(self.target_adjust_layers)
                        self.is_adjusted = False
                        logger.info(f"Worker {self.pp_rank} adjust layer to {self.target_adjust_layers},self.adjust_flag: {self.adjust_flag} at {time.time()}")
        return output

    # rank 0
    def process_output(self, output):
        """Consume next tokens for the oldest in-flight batch and report them.

        Token source: with pp_size == 1, ``output`` (a list) holds the next
        tokens directly. Otherwise they are read, if available, from the last
        rank's token socket.

        For each seq in the batch:
        - its computed count is advanced;
        - if its prompt is fully computed, the new token is appended and reported;
        - finished seqs are freed; the rest are requeued for decode or for
          further prefill.

        The result is pushed to the main process as ``(SchedulerOutput, send_tokens)``.
        """
        next_tokens = None
        if isinstance(output, list):  # world_size == 1
            next_tokens = output
        elif self.pp_size != 1 and self.token_socket.poll(timeout=0) != 0:  # recv tokens from last rank
            recv_bytes = self.token_socket.recv(copy=False)
            next_tokens = pickle.loads(recv_bytes)

        if next_tokens is not None:
            schedule_seqs: List[Sequence] = self.batch_running.popleft()
            assert len(next_tokens) == len(schedule_seqs)

            send_tokens = []
            schedulerOutput = SchedulerOutput([])

            for seq, next_token in zip(schedule_seqs, next_tokens):
                seq.computed_token_num += seq.to_compute_token_num
                if seq.computed_prompt():
                    schedulerOutput.act_schedule_ids.append(seq.seq_id)
                    send_tokens.append(next_token)
                    seq.token_ids.append(next_token)
                if seq.is_finish():
                    schedulerOutput.free_ids.append(seq.seq_id)
                    self.model_runner.memory_manager.free(seq)
                elif seq.computed_prompt():
                    self.seqs_to_decode.appendleft(seq)
                else:
                    self.seqs_to_prefill.appendleft(seq)

            output_bytes = pickle.dumps((schedulerOutput, send_tokens))
            self.output_socket.send(output_bytes, copy=False)
 
def _shutdown_worker(worker):
    """Best-effort teardown: destroy the process group if it exists, then mark the worker dead.

    ``mp_alive[pp_rank] = -1`` is written even if ``init`` failed before the
    process group was created, or if teardown itself raises.
    """
    try:
        # destroy_process_group() raises when no process group was initialised
        # (e.g. init_dist failed), which would otherwise skip the -1 marker below.
        if dist.is_initialized():
            dist.destroy_process_group()
    except Exception as e:
        logger.error(f'Worker {worker.pp_rank} failed to destroy process group: {e}')
    worker.mp_alive[worker.pp_rank] = -1


def run_worker(worker: Worker):
    """Process entry point: initialise ``worker`` and drive its event loop forever.

    Rank 0 alternates ``schedule_run`` / ``process_output``; other ranks call
    ``run``. On KeyboardInterrupt or any other Exception, the error (if any) is
    logged, the process group is torn down, and ``worker.mp_alive[pp_rank]``
    is set to -1.
    """
    try:
        worker.init()
        logger.info(f'Worker {worker.pp_rank} init')
        while True:
            if worker.pp_rank == 0:
                output = worker.schedule_run()
                worker.process_output(output)
            else:
                worker.run()
    except KeyboardInterrupt:
        logger.info(f'Worker {worker.pp_rank} exit')
        _shutdown_worker(worker)
    except Exception as e:
        logger.error(f'Worker {worker.pp_rank} \n{e}')
        traceback.print_exc()
        _shutdown_worker(worker)