import os
import torch.distributed as dist
import torch
def send_pp_data(output, dst):
    """Post asynchronous sends of ``output`` to rank ``dst``.

    Args:
        output: Either a single object accepted by ``dist.isend``, or a
            tuple of exactly two such objects, which are sent one after the
            other in tuple order.
        dst: Destination rank forwarded to ``dist.isend``.

    Raises:
        AssertionError: If ``output`` is a tuple whose length is not 2.

    Note:
        The request handles returned by ``dist.isend`` are not kept, so this
        function neither waits for nor reports completion of the sends.
    """
    # isinstance (instead of an exact type comparison) also recognises tuple
    # subclasses such as namedtuples as a two-part payload.
    if isinstance(output, tuple):
        # Only two-element tuples are supported.
        assert len(output) == 2
        first, second = output
        # Send both parts asynchronously, first element first.
        dist.isend(first, dst)
        dist.isend(second, dst)
    else:
        # Not a tuple: send the whole output as a single message.
        dist.isend(output, dst)

def recv_pp_data(src, dtype, shape, has_residual):
    """Allocate receive buffers and post asynchronous receives from ``src``.

    The buffers are zero-filled tensors of the given ``shape`` and ``dtype``
    placed on ``cuda:<r>``, where ``r`` is the value of ``dist.get_rank()``.

    Args:
        src: Source rank forwarded to ``dist.irecv``.
        dtype: Element type of the receive buffers.
        shape: Shape of the receive buffers (wrapped in ``torch.Size``).
        has_residual: When true, a second buffer of the same shape and dtype
            is allocated and a second receive is posted after the first.

    Returns:
        ``(hidden_handle, residual_handle, hidden_states, residual)`` when
        ``has_residual`` is true, otherwise ``(hidden_handle, hidden_states)``.
        The handles are whatever ``dist.irecv`` returns; the tensors are the
        buffers the receives write into.
    """
    device = f'cuda:{dist.get_rank()}'
    hidden_states = torch.zeros(torch.Size(shape), dtype=dtype, device=device)

    if not has_residual:
        return dist.irecv(hidden_states, src), hidden_states

    # Allocate the second buffer before posting any receive, then post the
    # receives in the fixed order: hidden states first, residual second.
    residual = torch.zeros_like(hidden_states)
    hidden_handle = dist.irecv(hidden_states, src)
    residual_handle = dist.irecv(residual, src)
    return hidden_handle, residual_handle, hidden_states, residual
    
# Helpers for asynchronously transferring the KV cache
def send_tensor(dst, tensor):
    """Start an asynchronous send of ``tensor`` to rank ``dst``.

    The original author's comment indicates this is used for KV-cache
    transfers.

    Args:
        dst: Destination rank forwarded to ``dist.isend``.
        tensor: Tensor to send.

    Returns:
        The request handle returned by ``dist.isend``.
    """
    return dist.isend(tensor, dst)

def recv_tensor(src, target_tensor):
    """Start an asynchronous receive from rank ``src`` into ``target_tensor``.

    Args:
        src: Source rank forwarded to ``dist.irecv``.
        target_tensor: Pre-allocated tensor that the receive writes into.

    Returns:
        The request handle returned by ``dist.irecv``.
    """
    return dist.irecv(target_tensor, src)


# Pipeline-parallel rank of this process; overwritten by init_dist().
_PP_RANK=0
# Number of pipeline-parallel stages; overwritten by init_dist().
_PP_SIZE=1
# Layer-adjustment value used by the "used layers" helpers; changed only
# through set_pp_adjust_layers().
_PP_ADJUST_LAYERS=0

def get_pp_rank():
    """Return the module-level pipeline-parallel rank (``_PP_RANK``)."""
    return _PP_RANK

def get_pp_size():
    """Return the module-level pipeline-parallel size (``_PP_SIZE``)."""
    return _PP_SIZE

def get_pp_last_rank():
    """Return the pipeline rank that precedes this one, with wrap-around.

    "Last" here means "previous": rank 0 maps to ``_PP_SIZE - 1`` and every
    other rank ``r`` maps to ``r - 1``.

    NOTE(review): the original comment describes the result as a device
    rank, but it is computed purely from the pipeline rank and size —
    confirm that the two numberings coincide for the callers.
    """
    # Python's % yields a non-negative result for a positive modulus, so no
    # explicit "+ _PP_SIZE" is needed to wrap rank 0 around.
    return (_PP_RANK - 1) % _PP_SIZE

def set_pp_adjust_layers(adjust_layers):
    """Set the module-level layer-adjustment value.

    Args:
        adjust_layers (int): 0 means layers are split evenly; a value in
            1..pp_size-1 is the number of layers to adjust.

    Raises:
        AssertionError: If ``adjust_layers`` is outside ``[0, _PP_SIZE - 1]``.
            The bound uses the current ``_PP_SIZE``, which is 1 until
            ``init_dist`` has been called.
    """
    global _PP_ADJUST_LAYERS
    max_adjust = _PP_SIZE - 1
    assert 0 <= adjust_layers <= max_adjust, "调整层数必须在(0,pp_size-1)之间"
    _PP_ADJUST_LAYERS = adjust_layers

def get_pp_adjust_layers():
    """Return the module-level layer-adjustment value (``_PP_ADJUST_LAYERS``)."""
    return _PP_ADJUST_LAYERS

def init_dist(pp_size, pp_rank, device_size, device_rank, master_addr, master_port):
    """Record the pipeline-parallel topology and join the NCCL process group.

    Args:
        pp_size: Number of pipeline stages; stored in ``_PP_SIZE``.
        pp_rank: This process's pipeline rank; stored in ``_PP_RANK``.
        device_size: ``world_size`` passed to ``dist.init_process_group``.
        device_rank: ``rank`` passed to ``dist.init_process_group``.
        master_addr: Value exported as the ``MASTER_ADDR`` environment
            variable.
        master_port: Value exported as the ``MASTER_PORT`` environment
            variable; may be given as a string or an int.
    """
    global _PP_RANK, _PP_SIZE
    _PP_RANK = pp_rank
    _PP_SIZE = pp_size
    os.environ['MASTER_ADDR'] = master_addr
    # os.environ only accepts str values; converting here keeps string
    # callers unchanged and stops an int port from raising TypeError.
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend='nccl', world_size=device_size, rank=device_rank)

def get_pp_load_layers_native(num_layers):
    """Return the ``(start, end)`` layer range for this stage with no adjustment.

    Uses the same window formula as the adjustable helpers but with the
    adjustment fixed at 0.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        A ``(start, end)`` tuple of layer indices.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    per_stage = num_layers // pp_size
    rank = get_pp_rank()

    # Per-stage layer counts produced by the formula for 64 layers on 4
    # stages, keyed by the adjustment value:
    #   0 -> 16 16 16 16
    #   1 -> 16 16 17 15
    #   2 -> 16 17 17 14
    #   3 -> 17 17 17 13
    adjust_layers = 0  # fixed here; decides how the initial split is made

    begin = per_stage * rank
    end = per_stage * (rank + 1)
    if rank >= pp_size - adjust_layers - 1:
        # Stages at or beyond the threshold get a shifted window.
        begin += adjust_layers + rank + 1 - pp_size
        end += (adjust_layers + rank + 2 - pp_size) % (adjust_layers + 1)
    return begin, end
 

def get_pp_load_layers(num_layers):
    """Return the ``(start, end)`` range of layers this stage loads.

    Every stage except the final one loads ``(rank + 1) % 4`` layers past
    its even share; the final stage loads through ``num_layers``.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        A ``(start, end)`` tuple of layer indices.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    per_stage = num_layers // pp_size
    rank = get_pp_rank()
    start = per_stage * rank

    if rank == pp_size - 1:
        return start, num_layers

    # TODO: implement layer migration later instead of pre-allocating.
    # A few extra layers are loaded up front so that later dynamic
    # re-balancing can use them.
    # NOTE(review): the modulus 4 is hard-coded — presumably tied to a
    # 4-stage setup; confirm before using other pipeline sizes.
    return start, per_stage * (rank + 1) + (rank + 1) % 4




def get_pp_used_layers(num_layers):
    """Return the ``(start, end)`` range of layers this stage actually uses.

    The range depends on the current value of ``get_pp_adjust_layers()``:
    stages below ``pp_size - adjust - 1`` keep their even share, later
    stages use a window shifted towards higher layer indices.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        A ``(start, end)`` tuple of layer indices.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    per_stage = num_layers // pp_size
    # TODO: implement layer migration later instead of pre-allocating.
    # Extra layers are pre-allocated so dynamic changes can use them.
    adjust = get_pp_adjust_layers()
    rank = get_pp_rank()

    shifted = rank >= pp_size - adjust - 1
    offset = adjust + rank + 1 - pp_size if shifted else 0
    extra = (adjust + rank + 2 - pp_size) % (adjust + 1) if shifted else 0
    return per_stage * rank + offset, per_stage * (rank + 1) + extra
    

    

def get_pp_used_idx_layers(num_layers,adjust_layers):
    """Return the ``(start, end)`` used-layer range for a given adjustment.

    Same computation as the module-state-based variant, but the adjustment
    is supplied by the caller instead of read from module state.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.
        adjust_layers: Adjustment value to apply.

    Returns:
        A ``(start, end)`` tuple of layer indices.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    per_stage = num_layers // pp_size
    rank = get_pp_rank()

    first = per_stage * rank
    last = per_stage * (rank + 1)
    # How far this stage's window start moves; negative means the stage is
    # below the threshold and keeps its even share.
    shift = adjust_layers + rank + 1 - pp_size
    if shift < 0:
        return first, last
    return first + shift, last + (shift + 1) % (adjust_layers + 1)

  

def get_pp_load_num_layers(num_layers):
    """Return how many layers this stage loads.

    The count equals the length of the range produced by
    ``get_pp_load_layers``: a non-final stage loads its even share plus
    ``(rank + 1) % 4`` spare layers, and the final stage loads everything
    from its start index through ``num_layers``.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        The number of layers loaded by this stage.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    num_layers_pp = num_layers // pp_size
    rank = get_pp_rank()
    # TODO: implement layer migration later instead of pre-allocating.
    # A few extra layers are loaded up front so that later dynamic
    # re-balancing can use them.
    if rank != pp_size - 1:
        return num_layers_pp + (rank + 1) % 4
    # The final stage has no spare layers after it. Applying the
    # "(rank + 1) % 4" term here would over-count whenever pp_size is not a
    # multiple of 4 (e.g. pp_size == 1 would report one layer too many).
    return num_layers - num_layers_pp * rank


def get_pp_used_num_layers(num_layers):
    """Return how many layers this stage actually uses.

    With adjustment ``a = get_pp_adjust_layers()``: stages below
    ``pp_size - a - 1`` use the even share, the remaining non-final stages
    use one layer more, and the final stage uses ``a`` layers fewer.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        The number of layers used by this stage.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    per_stage = num_layers // pp_size
    # TODO: implement layer migration later instead of pre-allocating.
    # Extra layers are pre-allocated so dynamic changes can use them.
    adjust = get_pp_adjust_layers()
    rank = get_pp_rank()

    if rank < pp_size - adjust - 1:
        return per_stage
    return per_stage - adjust if rank == pp_size - 1 else per_stage + 1
        


# TODO: kept for other models so they do not break; can be deleted once they have been migrated.
def get_pp_layers(num_layers):
    """Return the ``(start, end)`` layer range for this stage (legacy split).

    Each non-final stage takes ``num_layers // pp_size + 1`` layers; the
    final stage takes whatever remains up to ``num_layers``.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        A ``(start, end)`` tuple of layer indices.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    rank = get_pp_rank()
    # The last stage usually takes longer to run, so every earlier stage is
    # given one extra layer.
    per_stage = num_layers // pp_size + 1
    start = per_stage * rank
    end = num_layers if rank == pp_size - 1 else per_stage * (rank + 1)
    return start, end

def get_pp_num_layers(num_layers):
    """Return how many layers this stage holds under the legacy split.

    Non-final stages hold ``num_layers // pp_size + 1`` layers; the final
    stage holds that amount minus ``pp_size``.

    Args:
        num_layers: Total number of layers; asserted to be divisible by the
            pipeline size.

    Returns:
        The number of layers for this stage.
    """
    pp_size = get_pp_size()
    assert num_layers % pp_size == 0
    # The last stage usually takes longer to run, so every earlier stage is
    # given one extra layer.
    per_stage = num_layers // pp_size + 1
    is_final_stage = get_pp_rank() == pp_size - 1
    return per_stage - pp_size if is_final_stage else per_stage
    