import functools
from typing import List

import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util

from verl.workers.rollout.vllm_rollout.llm import LLM
from verl.workers.rollout.vllm_rollout.pref_instance_cls import SAGEParams as BeamSearchParams, OutputObject, SAGESequence as BeamSearchSequence, SAGEOutput as BeamSearchOutput


def generate_uniform_indicator_tensor(tensor):
    """Flag every element that belongs to a non-uniform row of ``tensor``.

    A row counts as uniform when its sum along dim 1 equals 0 or equals the
    row length ``tensor.shape[1]``; for a 0/1 tensor that means "all zeros"
    or "all ones".

    Args:
        tensor: Tensor that is summed along dim 1 (presumably a (B, S) matrix
            of 0/1 values -- confirm with callers).

    Returns:
        A flattened bool tensor with one entry per element of ``tensor``:
        False for elements of uniform rows, True for all other elements.
    """
    row_length = tensor.shape[1]
    totals = tensor.sum(dim=1)

    # De Morgan form of "not (sum == 0 or sum == row_length)".
    keep_row = (totals != 0) & (totals != row_length)

    # Spread each row's flag over the whole row, then flatten the result.
    return keep_row.unsqueeze(1).expand_as(tensor).flatten()


def transform_AAA_to_A(tensor, n):
    """Collapse a row-repeated batch by keeping the first row of each group.

    Rows ``0, n, 2n, ...`` along dim 0 are selected. When the number of rows
    is not a multiple of ``n``, the first row of the trailing partial group
    is kept as well.

    Args:
        tensor: Tensor with at least two dimensions whose dim 0 holds the
            grouped rows.
        n: Group size; must be a positive integer.

    Returns:
        A new tensor (a copy, not a view) containing the selected rows.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive group size, got {n}")

    first_rows = torch.arange(0, tensor.shape[0], step=n, device=tensor.device)
    # Index-tensor selection copies the rows, so the result never aliases
    # the input storage.
    return tensor[first_rows, :]


def fetch_single_prompt_list(lst, n):
    """Return the first element of every complete group of ``n`` items.

    Args:
        lst: Sequence laid out as consecutive groups of ``n`` items.
        n: Group size.

    Returns:
        A list holding ``lst[0], lst[n], lst[2n], ...``, one entry per
        complete group; a trailing partial group is ignored.
    """
    group_count = len(lst) // n
    # Start offsets of the complete groups: 0, n, ..., (group_count - 1) * n.
    group_starts = range(0, group_count * n, n)
    return [lst[start] for start in group_starts]


def fetch_group_first_x(lst, n, x):
    """Take the first ``x`` items from every complete group of ``n`` items.

    Args:
        lst: Sequence laid out as consecutive groups of ``n`` items.
        n: Group size.
        x: Number of leading items to take from each group.

    Returns:
        A flat list with the selected items in their original order. A
        trailing partial group is ignored.
    """
    group_count = len(lst) // n
    picked = []
    for group_start in range(0, group_count * n, n):
        for offset in range(x):
            picked.append(lst[group_start + offset])
    return picked


def fetch_group_first_elements(lst, n, repeat_times):
    """Take a per-group number of leading items from each complete group.

    Args:
        lst: Sequence laid out as consecutive groups of ``n`` items.
        n: Group size.
        repeat_times: Indexable of ints; ``repeat_times[i]`` is how many
            leading items are taken from group ``i``.

    Returns:
        A flat list with the selected items in their original order. A
        trailing partial group is ignored.
    """
    group_count = len(lst) // n
    return [
        lst[group_idx * n + offset]
        for group_idx in range(group_count)
        for offset in range(repeat_times[group_idx])
    ]



def compute_matrix_average(sim_matrix, include_diagonal=False):
    """Return the mean of a square similarity matrix.

    Args:
        sim_matrix: Similarity matrix of shape [n, n]; may be a torch.Tensor,
            a numpy array, or nested lists.
        include_diagonal: When True, average every entry. When False, skip
            the diagonal (each item's similarity with itself).

    Returns:
        The mean as a Python float. With ``include_diagonal=False`` and a
        1x1 matrix no entries remain, so the result is NaN.
    """
    # as_tensor reuses an existing tensor/ndarray; torch.tensor() would
    # copy-construct and emit a UserWarning when handed a tensor.
    sim_matrix = torch.as_tensor(sim_matrix)
    if not (sim_matrix.is_floating_point() or sim_matrix.is_complex()):
        # mean() is not defined for integer/bool dtypes.
        sim_matrix = sim_matrix.float()

    if include_diagonal:
        return sim_matrix.mean().item()

    size = sim_matrix.shape[0]
    # True everywhere except on the diagonal.
    off_diagonal = ~torch.eye(size, dtype=torch.bool, device=sim_matrix.device)
    return sim_matrix[off_diagonal].mean().item()


@functools.lru_cache(maxsize=4)
def _load_sentence_model(model_name):
    """Build the SentenceTransformer for ``model_name`` once and reuse it."""
    return SentenceTransformer(model_name)


def compute_semantic_similarity(texts: List[str], model_name: str = "all-MiniLM-L6-v2") -> np.ndarray:
    """Compute pairwise cosine similarity between all texts.

    Args:
        texts: Input texts.
        model_name: Sentence-BERT model name; the default is the lightweight
            ``all-MiniLM-L6-v2`` (the first use downloads it, roughly 80MB
            according to the original note).

    Returns:
        A 2-D numpy array of shape [len(texts), len(texts)] where entry
        [i][j] is the similarity between ``texts[i]`` and ``texts[j]``.
        An empty input yields an empty (0, 0) array.
    """
    if len(texts) == 0:
        # Nothing to encode; honor the documented [n, n] shape contract.
        return np.zeros((0, 0), dtype=np.float32)

    # The model is cached per name so repeated calls do not reload it.
    model = _load_sentence_model(model_name)

    # Encode to a tensor of sentence embeddings, one row per text.
    embeddings = model.encode(texts, convert_to_tensor=True)

    # Pairwise cosine similarity, moved to CPU and converted to numpy.
    return util.cos_sim(embeddings, embeddings).cpu().numpy()


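# NOTE: inside this function `n` is the number of accuracy entries (the batch size); it is not the same as the final `n`.
# Entries whose accuracy is below `acc_threshold` have their ratio forced to 1.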
def get_adaptive_merge_ratio(accuracy_list, acc_threshold=0.75):
    """Convert accuracies into rank-based ratios in (0, 1].

    Entries are ranked by descending accuracy; the entry at 0-based rank
    ``r`` receives ``(r + 1) / count``. Any entry whose accuracy is below
    ``acc_threshold`` is then overridden to 1.

    Args:
        accuracy_list: Sequence of numeric accuracies.
        acc_threshold: Accuracies strictly below this value get ratio 1.

    Returns:
        A list of floats aligned with ``accuracy_list``; an empty list for
        empty input.
    """
    acc_array = np.array(accuracy_list)
    count = len(acc_array)
    if count == 0:
        return []

    descending_order = np.argsort(-acc_array)
    # The argsort of a permutation is its inverse, so ranks[i] is the
    # position of entry i in the descending ordering.
    ranks = np.argsort(descending_order)
    ratios = (ranks + 1) / count

    # Low-accuracy entries are pinned to the maximum ratio.
    ratios[acc_array < acc_threshold] = 1
    return ratios.tolist()
    


def create_attention_mask(lengths, max_seq_len=10):
    """Build a boolean mask marking the first ``lengths[i]`` positions of row i.

    Args:
        lengths: Sequence of per-row valid lengths.
        max_seq_len: Mask width. When None, the largest value in ``lengths``
            is used.

    Returns:
        A bool tensor of shape [len(lengths), max_seq_len] where entry
        [i, j] is True exactly when ``j < lengths[i]``.
    """
    width = max(lengths) if max_seq_len is None else max_seq_len

    # A row of positions (1, width) compared against a column of lengths
    # (B, 1) broadcasts to the full (B, width) mask.
    positions = torch.arange(width).unsqueeze(0)
    valid_lengths = torch.tensor(lengths).unsqueeze(1)
    return positions < valid_lengths


def convert_generate_to_beam_search(vllm_outputs, k, max_len):
    """Regroup flat generation outputs into beam-search style groups of ``k``.

    For every output, the first candidate (``outputs[0]``) is annotated in
    place with:
      * ``tokens``: prompt token ids followed by at most ``max_len``
        generated token ids,
      * ``logprobs``: prompt logprobs followed by at most ``max_len``
        generated logprobs,
      * ``prompt_length``: number of prompt tokens plus one.
    Consecutive candidates are then packed ``k`` at a time into
    ``BeamSearchOutput`` objects.

    Args:
        vllm_outputs: Iterable of generation results exposing
            ``prompt_token_ids``, ``prompt_logprobs`` and ``outputs``
            (presumably vLLM request outputs -- confirm with callers).
        k: Number of consecutive sequences per group.
        max_len: Maximum number of generated tokens/logprobs to keep.

    Returns:
        A list of ``BeamSearchOutput`` objects. Trailing sequences that do
        not fill a complete group of ``k`` are dropped.
    """
    grouped = []
    pending = []
    for request_output in vllm_outputs:
        prompt_ids = request_output.prompt_token_ids
        candidate = request_output.outputs[0]

        candidate.tokens = prompt_ids + candidate.token_ids[:max_len]
        candidate.logprobs = request_output.prompt_logprobs + candidate.logprobs[:max_len]
        # The +1 offset comes from the original "fix the bug" note; its
        # rationale is not visible in this file -- TODO confirm.
        candidate.prompt_length = len(prompt_ids) + 1
        pending.append(candidate)

        if len(pending) != k:
            continue

        # A full group is ready: wrap it and start collecting the next one.
        grouped.append(BeamSearchOutput(sequences=pending))
        pending = []

    return grouped

# print(create_attention_mask([1,2,3,10]))
# print(create_attention_mask([1,2,3,4]).shape)

