import torch
import zmq
import pickle
import torch.distributed as dist
import traceback
import time
import random

from collections import deque
from logger import logger
from typing import List
from functools import reduce

from gllm.input_data import InputData
from gllm.sequence import Sequence
from gllm.model_runner import ModelRunner
from gllm.dist_utils import init_dist, send_pp_data, send_tensor,recv_pp_data,recv_tensor,get_pp_used_idx_layers,get_pp_load_layers
from gllm.scheduler import SchedulerOutput
from gllm.utils import make_socket
from gllm.memory_manager import PrefixMemoryManager
import math
# NOTE(review): not referenced in the visible part of this file; presumably a
# default for the number of adjusted pipeline layers — TODO confirm or remove.
ADJUST_LAYERS = 0

def predict_worker_time(prefill, decode, memory_util):
    """Estimate the duration of one forward step with a fitted linear model.

    Args:
        prefill: number of prefill tokens in the batch.
        decode: number of decode tokens (one per decode sequence) in the batch.
        memory_util: memory utilisation value reported by the memory manager.

    Returns:
        float: the predicted time (presumably in seconds — TODO confirm units).
    """
    base_cost = 0.010746                     # constant overhead
    prefill_cost = 0.000082 * prefill        # per prefill token
    decode_cost = 0.000032 * decode          # per decode token
    memory_cost = 0.000138 * memory_util     # per unit of memory utilisation
    # Summed left to right so the floating-point result matches a single
    # chained expression bit for bit.
    return base_cost + prefill_cost + decode_cost + memory_cost

def predict_sample_time(prefill, decode, memory_util):
    """Estimate the duration of the sampling step with a fitted linear model.

    Args:
        prefill: number of prefill tokens in the batch.
        decode: number of decode tokens (one per decode sequence) in the batch.
        memory_util: memory utilisation value reported by the memory manager.

    Returns:
        float: the predicted time (presumably in seconds — TODO confirm units).
    """
    base_cost = 0.001907                     # constant overhead
    prefill_cost = 0.000001 * prefill        # per prefill token
    decode_cost = 0.000041 * decode          # per decode token
    memory_cost = -0.000003 * memory_util    # fitted coefficient is negative
    # Summed left to right so the floating-point result matches a single
    # chained expression bit for bit.
    return base_cost + prefill_cost + decode_cost + memory_cost
# Used with PipeAsyncLLM
class Worker:
    
    def __init__(self, model_runner:ModelRunner, pp_rank, pp_size, 
                 master_addr, master_port, schedule_ipc_path, layer_adjust_ipc_path,output_ipc_path, token_ipc_path,mp_alive,enable_adjust_layers=False):
        self.model_runner = model_runner
        self.pp_rank = pp_rank # pp rank
        self.pp_size = pp_size # pp size
        self.master_addr = master_addr
        self.master_port = master_port
        self.schedule_ipc_path = schedule_ipc_path
        self.output_ipc_path = output_ipc_path
        self.token_ipc_path = token_ipc_path
        self.mp_alive = mp_alive
        self.enable_adjust_layers = enable_adjust_layers

        if self.enable_adjust_layers:
            self.layer_adjust_ipc_path = layer_adjust_ipc_path
            if self.pp_rank == 0:
                self.ratio = 0
                self.count = 0
    
    def get_pp_next_rank(self):
        # return device_rank of next pp_rank
        return (self.pp_rank + 1) % self.pp_size
        
    def get_pp_last_rank(self):
        # return device_rank of last pp_rank
        return (self.pp_rank - 1 + self.pp_size) % self.pp_size
    
    def init(self):
        """Join the process group, pick the CUDA device, open the ZeroMQ sockets and load the model.

        Rank 0 pulls scheduling requests from the main process, pushes results
        back to it, pushes batches (and, if enabled, layer-adjust commands) to
        every other rank, and pulls next tokens from the last rank. Every other
        rank pulls batches (and layer-adjust commands) from rank 0; the last
        rank additionally pushes next tokens to rank 0. Finally the model
        runner is initialised, a few of its settings are cached on the worker
        and mp_alive[pp_rank] is set to 1.
        """
        init_dist(self.pp_size, self.pp_rank, self.pp_size, self.pp_rank, self.master_addr, self.master_port)
        torch.cuda.set_device(f'cuda:{self.pp_rank}')
        zmq_ctx = zmq.Context()
        if self.pp_rank == 0:
            # main process => rank 0
            self.schedule_socket = make_socket(zmq_ctx, self.schedule_ipc_path, zmq.PULL) 
            if self.enable_adjust_layers:
                # NOTE(review): self.layer_adjust_socket is rebound to a list a few
                # lines below, so the PULL socket created here is no longer
                # referenced afterwards — confirm whether it is needed at all.
                self.layer_adjust_socket = make_socket(zmq_ctx, self.layer_adjust_ipc_path, zmq.PULL)
            # rank 0 => main process
            self.output_socket = make_socket(zmq_ctx, self.output_ipc_path, zmq.PUSH)
            # seqs to schedule
            self.seqs_to_prefill: deque[Sequence] = deque()
            self.seqs_to_decode: deque[Sequence] = deque()
            # running batch 
            self.batch_running = deque()
            self.log_time = 0
            # preempt seqs
            self.num_preempt_seqs = 0
            self.log_num_preempt_seqs = 0
            # num wait tokens
            self.num_wait_tokens = 0
            # rank 0 => other ranks : batched seqs
            # Index i-1 of each list holds the PUSH socket for rank i.
            self.gpu_schedule_socket = []
            self.layer_adjust_socket = []
            for i in range(1,self.pp_size):
                self.gpu_schedule_socket.append(make_socket(zmq_ctx, f'{self.schedule_ipc_path}_{i}',zmq.PUSH))
                if self.enable_adjust_layers:
                    self.layer_adjust_socket.append(make_socket(zmq_ctx, f'{self.layer_adjust_ipc_path}_{i}',zmq.PUSH))
            if self.pp_size != 1:
                # last rank => rank 0 : next tokens
                self.token_socket = make_socket(zmq_ctx, self.token_ipc_path, zmq.PULL)
        else:
            # rank 0 => other ranks : batched seqs
            self.gpu_schedule_socket = make_socket(zmq_ctx, f'{self.schedule_ipc_path}_{self.pp_rank}', zmq.PULL) # other pipeline stages: receive the scheduled requests
            if self.enable_adjust_layers:
                self.layer_adjust_socket = make_socket(zmq_ctx, f'{self.layer_adjust_ipc_path}_{self.pp_rank}', zmq.PULL) # other pipeline stages: receive layer-adjust requests
            # Input data for each rank except 0 
            self.schedule_queue = deque()
            # Input data and intermediate data for rank except 0
            self.run_queue = deque()
        if self.pp_rank == self.pp_size - 1 and self.pp_size != 1:
            # last rank => rank 0 : next tokens
            self.token_socket = make_socket(zmq_ctx, self.token_ipc_path, zmq.PUSH) # the last pipeline stage also sends the tokens
        self.model_runner.init()
        # Cache frequently used settings from the model runner on the worker.
        self.dtype = self.model_runner.memory_manager.dtype
        self.hidden_size = self.model_runner.model_loader.hidden_size
        self.ret_residual = self.model_runner.model.ret_residual
        self.max_decode_seqs = self.model_runner.max_decode_seqs
        self.max_batch_tokens = self.model_runner.max_batch_tokens
        self.page_size = self.model_runner.page_size
        self.ratio_threshold_free_pages = self.model_runner.ratio_threshold_free_pages
        # Absolute page count derived from the ratio and the number of pages
        # that are free at this point (right after model_runner.init()).
        self.num_threshold_free_pages = int(
            self.model_runner.ratio_threshold_free_pages * self.model_runner.memory_manager.get_num_free_pages())
        
        self.mp_alive[self.pp_rank] = 1 # the worker on the corresponding GPU started successfully
    
    def get_num_free_pages(self):
        return self.model_runner.memory_manager.get_num_free_pages()
    

    def _adjust_layer(self,target_adjust_layers,cur_adjust_layers):
        
        cur_start_used_layer_id,cur_end_used_layer_id = get_pp_used_idx_layers(self.model_runner.model.model.num_hidden_layers,cur_adjust_layers)
        target_start_used_layer_id,target_end_used_layer_id = get_pp_used_idx_layers(self.model_runner.model.model.num_hidden_layers,target_adjust_layers)
        load_start_layer_idx,load_end_layer_idx = get_pp_load_layers(self.model_runner.model.model.num_hidden_layers)

        if cur_adjust_layers > target_adjust_layers: # 往下调整
            recv_layers_num = cur_start_used_layer_id - target_start_used_layer_id
            recv_k_future_list = []
            recv_v_future_list = []
            recv_offset = target_start_used_layer_id - load_start_layer_idx
            for i in range(0,recv_layers_num):
                target_k_cache = self.model_runner.memory_manager.segment.k_cache[recv_offset+i]
                target_v_cache = self.model_runner.memory_manager.segment.v_cache[recv_offset+i]
                recv_k_future = recv_tensor(self.get_pp_last_rank(),target_k_cache) 
                recv_v_future = recv_tensor(self.get_pp_last_rank(),target_v_cache) 
                recv_k_future_list.append(recv_k_future)
                recv_v_future_list.append(recv_v_future)                   


            send_layers_num = cur_end_used_layer_id - target_end_used_layer_id
            send_k_future_list = []
            send_v_future_list = []
            send_offset = target_end_used_layer_id - load_start_layer_idx
            for i in range(0,send_layers_num):
                send_k_cache = self.model_runner.memory_manager.segment.k_cache[send_offset+i]
                send_v_cache = self.model_runner.memory_manager.segment.v_cache[send_offset+i]
                send_k_future = send_tensor(self.get_pp_next_rank(),send_k_cache)
                send_v_future = send_tensor(self.get_pp_next_rank(),send_v_cache)
                send_k_future_list.append(send_k_future)
                send_v_future_list.append(send_v_future)

            for i in range(0,recv_layers_num):
                recv_k_future_list[i].wait()
                recv_v_future_list[i].wait()

        else: # 往上调整
            send_layers_num = target_start_used_layer_id - cur_start_used_layer_id
            send_offset = cur_start_used_layer_id - load_start_layer_idx
            send_k_cache = self.model_runner.memory_manager.segment.k_cache[send_offset:send_offset+send_layers_num]
            send_v_cache = self.model_runner.memory_manager.segment.v_cache[send_offset:send_offset+send_layers_num]
            # 向前一个gpu发送kvcache
            send_k_future_list = []
            send_v_future_list = []
            for i in range(0,send_layers_num):
                send_k_future = send_tensor(self.get_pp_last_rank(),send_k_cache[i]) 
                send_v_future = send_tensor(self.get_pp_last_rank(),send_v_cache[i])
                send_k_future_list.append(send_k_future)
                send_v_future_list.append(send_v_future)

            recv_layers_num = target_end_used_layer_id - cur_end_used_layer_id
            # 从后一个gpu获取kvcache
            recv_k_future_list = []
            recv_v_future_list = []
            recv_offset = cur_end_used_layer_id - load_start_layer_idx
            for i in range(0,recv_layers_num):
                target_k_cache = self.model_runner.memory_manager.segment.k_cache[recv_offset+i]
                target_v_cache = self.model_runner.memory_manager.segment.v_cache[recv_offset+i]
                recv_k_future = recv_tensor(self.get_pp_next_rank(),target_k_cache)
                recv_v_future = recv_tensor(self.get_pp_next_rank(),target_v_cache)
                recv_k_future_list.append(recv_k_future)
                recv_v_future_list.append(recv_v_future)
            for i in range(0,recv_layers_num):
                recv_k_future_list[i].wait()
                recv_v_future_list[i].wait()

        self.model_runner.model.adjust_layer(target_adjust_layers)


    # ranks other than 0
    def run(self):
        if self.enable_adjust_layers and self.layer_adjust_socket.poll(timeout=0) != 0: # 查看有没有调整layer的信号
            recv_bytes = self.layer_adjust_socket.recv(copy=False)
            target_adjust_layers = pickle.loads(recv_bytes)
            curr_adjust_layers = self.model_runner.model.get_adjust_layers() # 目前调整的layers数量
            self._adjust_layer(target_adjust_layers,curr_adjust_layers)


        # model forward
        if len(self.run_queue) != 0:
            hidden_states = None
            residual = None
            input_data,intermediate_data = self.run_queue.popleft()
            if len(intermediate_data) == 4:
                if not (intermediate_data[0].is_completed() and intermediate_data[1].is_completed()):
                    self.run_queue.appendleft((input_data,intermediate_data))
                    return
                else:
                    hidden_states, residual = intermediate_data[2], intermediate_data[3]
            elif len(intermediate_data) == 2:
                if not intermediate_data[0].is_completed():
                    self.run_queue.appendleft((input_data,intermediate_data))
                    return
                else:
                    hidden_states = intermediate_data[1]
            else:
                assert 0
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            output = self.model_runner.step_once(input_data,hidden_states,residual)
            end_event.record()
            torch.cuda.synchronize()  # 等待事件完成
            elapsed_time_ms = start_event.elapsed_time(end_event) 
            logger.info(f'Worker rank {self.pp_rank} step time: {elapsed_time_ms}')
            if self.pp_rank == self.pp_size - 1:
                assert type(output) == list
                token_bytes = pickle.dumps(output)
                self.token_socket.send(token_bytes, copy=False)
            else:
                send_pp_data(output, self.get_pp_next_rank())
        
        # recv schedule seqs
        if self.gpu_schedule_socket.poll(timeout=0) != 0: # 返回非0值表示有新消息可接收
            recv_bytes = self.gpu_schedule_socket.recv(copy=False)
            seqs = pickle.loads(recv_bytes)
            self.schedule_queue.append(InputData(seqs,self.model_runner.memory_manager))
        if len(self.schedule_queue) != 0:
            # recv intermediate data
            input_data = self.schedule_queue.popleft()
            intermediate_data = recv_pp_data(
                self.get_pp_last_rank(), self.dtype, 
                [input_data.tokens.shape[0],self.hidden_size], self.ret_residual)
            self.run_queue.append((input_data,intermediate_data))
    
    # rank 0
    def get_num_decode_seqs(self):
        num_decode_seqs = len(self.seqs_to_decode) + reduce(lambda x,y: x+len(y),self.batch_running,0)
        return num_decode_seqs
    
    # rank 0
    def update_num_wait_tokens(self):
        self.num_wait_tokens = reduce(
            lambda x,y: x + len(y.token_ids) - y.computed_token_num,self.seqs_to_prefill,0)
    
    # rank 0: check if preempt seqs 
    def check_preempt(self,num_decode_tokens):
        preempt_seqs = []
        while self.model_runner.memory_manager.get_num_free_pages() < num_decode_tokens and len(self.seqs_to_decode) != 0:
            seq_to_preempt = self.seqs_to_decode.popleft()
            self.model_runner.memory_manager.free(seq_to_preempt)
            seq_to_preempt.preempt()
            preempt_seqs.append(seq_to_preempt)
            
        self.seqs_to_prefill.extendleft(preempt_seqs)
        
        self.num_preempt_seqs += len(preempt_seqs)
        if self.num_preempt_seqs - self.log_num_preempt_seqs >= 10:
            self.log_num_preempt_seqs = self.num_preempt_seqs
            logger.warning(f'#Preempted seqs: {self.num_preempt_seqs}')
            logger.warning('Try increase --ratio-free-pages or the performance is poor!')
    
    def schedule_naive(self):
        schedule_prefill_seqs = []
        schedule_decode_seqs = []
        
        num_tokens_budget = self.max_batch_tokens
        
        self.check_preempt(min(num_tokens_budget,len(self.seqs_to_decode)))
        # decode
        for _ in range(num_tokens_budget):
            if len(self.seqs_to_decode) == 0:
                break
            seq = self.seqs_to_decode.pop()
            seq.to_compute_token_num = 1
            schedule_decode_seqs.append(seq)
            
            
        self.model_runner.memory_manager.pre_allocate_page(schedule_decode_seqs)
        
        num_tokens_budget -= len(schedule_decode_seqs)
        
        # prefill
        prefill_batched_token_nums = 0
        while len(self.seqs_to_prefill) != 0 and num_tokens_budget != 0:
            seq = self.seqs_to_prefill.popleft()
            self.model_runner.memory_manager.pre_allocate_page([seq])
            if len(seq.token_ids)-seq.computed_token_num <= num_tokens_budget:
                seq.to_compute_token_num = len(seq.token_ids) - seq.computed_token_num
                prefill_batched_token_nums += seq.to_compute_token_num
                num_tokens_budget -= seq.to_compute_token_num
            else:
                prefill_batched_token_nums += num_tokens_budget
                seq.to_compute_token_num = num_tokens_budget
                num_tokens_budget = 0
            schedule_prefill_seqs.append(seq)

        if time.time()-self.log_time > 1:
            self.log_time = time.time()
            log_info = '#wait: %4d #run: %4d #prefill: %4d #decode: %4d memory_util: %5s %%' % (
                            len(self.seqs_to_prefill),
                            self.get_num_decode_seqs(),
                            prefill_batched_token_nums,
                            len(schedule_decode_seqs),
                            '%.2f' % self.model_runner.memory_manager.get_memory_util())
            if isinstance(self.model_runner.memory_manager, PrefixMemoryManager):
                log_info += ' cache_hit_rate: %5s %%' % ('%.2f' % self.model_runner.memory_manager.get_cache_hit_rate())
                logger.info(log_info)
            else:
                logger.info(log_info)
        return schedule_prefill_seqs + schedule_decode_seqs
    
    # rank 0: PP schedule 
    def schedule(self):
        
        schedule_prefill_seqs = []
        schedule_decode_seqs = []
        
        # prefill
        prefill_token_budget = self.page_size * max(self.get_num_free_pages()-self.num_threshold_free_pages,0)
        if self.pp_size > 1 and prefill_token_budget != 0:
            self.update_num_wait_tokens()
            free_ratio = self.model_runner.memory_manager.get_memory_free()
            # a = ratio_threshold_free_pages
            # free_ratio in [1,a] | prefill_ratio in [1,0]
            prefill_ratio = (free_ratio - self.ratio_threshold_free_pages) / (1-self.ratio_threshold_free_pages)
            prefill_ratio = max(prefill_ratio,0)
            prefill_token_budget = min(
                round(prefill_ratio * self.max_batch_tokens),
                prefill_token_budget)
            prefill_token_budget = min(max(self.num_wait_tokens//8,32),prefill_token_budget)
        else:
            prefill_token_budget = min(self.max_batch_tokens, prefill_token_budget)
        prefill_batched_token_nums = 0
        while len(self.seqs_to_prefill) != 0 and prefill_token_budget != 0:
            seq = self.seqs_to_prefill.popleft()
            if isinstance(self.model_runner.memory_manager, PrefixMemoryManager) and seq.computed_token_num == 0:
                self.model_runner.memory_manager.pre_allocate_computed_page([seq])
            if len(seq.token_ids)-seq.computed_token_num <= prefill_token_budget:
                seq.to_compute_token_num = len(seq.token_ids) - seq.computed_token_num
                prefill_batched_token_nums += seq.to_compute_token_num
                prefill_token_budget -= seq.to_compute_token_num
            else:
                prefill_batched_token_nums += prefill_token_budget
                seq.to_compute_token_num = prefill_token_budget
                prefill_token_budget = 0
            schedule_prefill_seqs.append(seq)

        self.model_runner.memory_manager.pre_allocate_page(schedule_prefill_seqs) # 为每个req没有computed的token分配页号写入seq.page_table
        
        # decode
        num_total_decode_seqs = self.get_num_decode_seqs()
        if num_total_decode_seqs < self.pp_size:
            decode_token_budget = num_total_decode_seqs
        else:
            # here we add num_total_decode_seqs to random.randint(0,self.pp_size-1))
            # because we want to solve the situation when #seqs=5 pp_size=4
            decode_token_budget = (num_total_decode_seqs + random.randint(0,self.pp_size-1)) // self.pp_size
        
        self.check_preempt(decode_token_budget)
        
        for _ in range(decode_token_budget):
            if len(self.seqs_to_decode) == 0:
                break
            seq = self.seqs_to_decode.popleft()
            seq.to_compute_token_num = 1
            schedule_decode_seqs.append(seq)
            
        self.model_runner.memory_manager.pre_allocate_page(schedule_decode_seqs)

        if self.pp_rank == 0:
            self.prefill_tokens = prefill_batched_token_nums
            self.decode_tokens = len(schedule_decode_seqs)
            if self.enable_adjust_layers and (schedule_prefill_seqs + schedule_decode_seqs != 0):
                forward_time = predict_worker_time(prefill_batched_token_nums, len(schedule_decode_seqs), self.model_runner.memory_manager.get_memory_util())
                sample_time = predict_sample_time(prefill_batched_token_nums, len(schedule_decode_seqs), self.model_runner.memory_manager.get_memory_util())
                logger.info(f"forward_time: {forward_time}, sample_time: {sample_time}")
                each_layer_forward_time = forward_time / self.model_runner.model.get_num_avg_layers()
                ratio = min(math.ceil(sample_time / each_layer_forward_time - 1),self.pp_size - 1)
                if ratio == self.ratio:
                    self.count += 1
                else:
                    self.count = 1
                self.ratio = ratio

        # TODO predict
        if time.time()-self.log_time > 0:
            self.log_time = time.time()
            log_info = '#wait: %4d/%8d #run: %4d #prefill: %4d #decode: %4d memory_util: %5s %%' % (
                            len(self.seqs_to_prefill),
                            self.num_wait_tokens,
                            num_total_decode_seqs,
                            prefill_batched_token_nums,
                            len(schedule_decode_seqs),
                            '%.2f' % self.model_runner.memory_manager.get_memory_util())
            if isinstance(self.model_runner.memory_manager, PrefixMemoryManager):
                log_info += ' cache_hit_rate: %5s %%' % ('%.2f' % self.model_runner.memory_manager.get_cache_hit_rate())
                logger.info(log_info)
            else:
                logger.info(log_info)
        # with open('log','a') as f:
        #     f.write(f'{prefill_batched_token_nums} {len(schedule_decode_seqs)}\n')
        return schedule_prefill_seqs + schedule_decode_seqs

    # rank 0
    def schedule_run(self):
        output = None
        # 检查是否有新的调度消息
        if self.schedule_socket.poll(timeout=0) != 0:
            # 接收并反序列化调度输出
            recv_bytes = self.schedule_socket.recv(copy=False) # 从PipeAsyncLLM发送过来的调度请求
            schedulerOutput: SchedulerOutput = pickle.loads(recv_bytes)
            # 将新的序列添加到预填充队列
            self.seqs_to_prefill.extend(schedulerOutput.schedule_lists) # 等待Prefill的请求

        if self.enable_adjust_layers:
            if not self.process_adjust_layer():
                return None

        
        if len(self.seqs_to_decode) + len(self.seqs_to_prefill) != 0 and len(self.batch_running) < self.pp_size:
            schedule_seqs = self.schedule() # 处理的请求数量

            if len(schedule_seqs) != 0:
                input_data = InputData(schedule_seqs, self.model_runner.memory_manager)
                if self.pp_size > 1:
                    seqs_bytes = pickle.dumps(schedule_seqs)
                    for i in range(1,self.pp_size):
                        self.gpu_schedule_socket[i-1].send(seqs_bytes,copy=False)
                self.batch_running.append(schedule_seqs)
                # start_event = torch.cuda.Event(enable_timing=True)
                # end_event = torch.cuda.Event(enable_timing=True)
                # start_event.record()
                output = self.model_runner.step_once(input_data)
                # torch.cuda.synchronize()
                # t2 = time.time()
                # logger.info(f'Worker rank {self.pp_rank} step time: {t2-t1}')
                
                if type(output) != list:
                    send_pp_data(output, self.get_pp_next_rank())
        
        return output
    
    def process_adjust_layer(self):
        # logger.info(f"len(seqs_to_decode):{len(self.seqs_to_decode)},len(seqs_to_prefill):{len(self.seqs_to_prefill)}")
        cur_adjust_layers = self.model_runner.model.get_adjust_layers() # 目前调整的layers数量
        target_adjust_layers = self.ratio
        if (target_adjust_layers == cur_adjust_layers) or self.count < 20: # 不用调整
            return True
        
        if len(self.batch_running) != 0:
            return False
              
        for i in range(1,self.pp_size):
            state_bytes = pickle.dumps(target_adjust_layers)
            self.layer_adjust_socket[i-1].send(state_bytes,copy=False)
        
        self._adjust_layer(target_adjust_layers,cur_adjust_layers)
        logger.info(f"Worker {self.pp_rank} adjust layer to {target_adjust_layers}")

        return True

            
        
    # rank 0
    def process_output(self, output):
        next_tokens = None
        if isinstance(output,list) : # word_size == 1
            next_tokens = output
        elif self.pp_size != 1 and self.token_socket.poll(timeout=0) != 0: # recv tokens from last rank
            recv_bytes = self.token_socket.recv(copy=False)
            next_tokens = pickle.loads(recv_bytes)
        
        if next_tokens is not None:
            schedule_seqs:List[Sequence] = self.batch_running.popleft()
            assert len(next_tokens) == len(schedule_seqs)

            send_tokens = []
            schedulerOutput = SchedulerOutput([])
            
            for idx, seq in enumerate(schedule_seqs):
                seq.computed_token_num += seq.to_compute_token_num
                if seq.computed_prompt():
                    schedulerOutput.act_schedule_ids.append(seq.seq_id)
                    send_tokens.append(next_tokens[idx])
                    seq.token_ids.append(next_tokens[idx])
                if seq.is_finish():
                    schedulerOutput.free_ids.append(seq.seq_id)
                    self.model_runner.memory_manager.free(seq)
                elif seq.computed_prompt():
                    self.seqs_to_decode.appendleft(seq)
                else:
                    self.seqs_to_prefill.appendleft(seq)
            
            output_bytes = pickle.dumps((schedulerOutput, send_tokens))
            self.output_socket.send(output_bytes, copy=False)
 
def _shutdown_worker(worker):
    """Tear down the process group (if one exists) and mark this rank as dead in mp_alive."""
    try:
        # destroy_process_group() fails when no default group exists, e.g.
        # when worker.init() raised before the group was created.
        if dist.is_initialized():
            dist.destroy_process_group()
    except Exception as e:
        logger.error(f'Worker {worker.pp_rank} failed to destroy process group\n{e}')
    finally:
        # Always publish the failure so the parent process can notice it.
        worker.mp_alive[worker.pp_rank] = -1


def run_worker(worker: Worker):
    """Process entry point: initialise the worker and run its event loop forever.

    Rank 0 alternates schedule_run() and process_output(); every other rank
    calls run(). On KeyboardInterrupt or any other Exception the error is
    logged, the process group is destroyed if it was initialised, and
    mp_alive[pp_rank] is set to -1. The exception is not re-raised.

    Args:
        worker: the Worker instance to drive.
    """
    try:
        worker.init()
        logger.info(f'Worker {worker.pp_rank} init')
        while True:
            if worker.pp_rank == 0:
                output = worker.schedule_run()
                worker.process_output(output)
            else:
                worker.run()
    except KeyboardInterrupt:
        logger.info(f'Worker {worker.pp_rank} exit')
        _shutdown_worker(worker)
    except Exception as e:
        logger.error(f'Worker {worker.pp_rank} \n{e}')
        traceback.print_exc()
        _shutdown_worker(worker)