import torch
import time
from itertools import chain

def rotate_half(x):
    """Swap the two halves of the last axis and negate the half moved to the front.

    The last axis is split at ``size // 2`` into ``[a, b]`` and ``[-b, a]`` is
    returned. When the size is odd, the second half ``b`` is the longer one.
    """
    split_at = x.shape[-1] // 2
    first, second = x[..., :split_at], x[..., split_at:]
    return torch.cat([second.neg(), first], dim=-1)

def masked_mini_filled_tensor(bool_mask, ref_dtype):
    """Build an additive mask: 0 where ``bool_mask`` is True, dtype-min elsewhere.

    The result has the same size and device as ``bool_mask`` and dtype
    ``ref_dtype``; the fill value is ``torch.finfo(ref_dtype).min``.
    """
    lowest = torch.finfo(ref_dtype).min
    shape = bool_mask.size()
    additive = torch.full(shape, fill_value=lowest, dtype=ref_dtype, device=bool_mask.device)
    additive[bool_mask] = 0.
    return additive

def apply_rotary_pos_emb_(x, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply a rotary position embedding to a single tensor.

    Computes ``x * cos + rotate_half(x) * sin`` after inserting a broadcast axis
    into ``cos`` and ``sin`` at ``unsqueeze_dim``.

    Args:
        x (`torch.Tensor`): The tensor to rotate (e.g. a query or a key tensor).
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Unused. Accepted only for signature compatibility. ``cos`` and ``sin``
            are used exactly as given, so any per-position gathering must happen
            before calling this function.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension at which ``cos`` and ``sin`` are unsqueezed so that they
            broadcast against ``x``. For example, if ``cos``/``sin`` have shape
            [batch_size, seq_len, head_dim] and ``x`` has shape
            [batch_size, heads, seq_len, head_dim], use 1. If ``x`` has shape
            [batch_size, seq_len, heads, head_dim], use 2.

    Returns:
        `torch.Tensor`: ``x`` rotated by the rotary position embedding. This is a
        new tensor; ``x`` itself is not modified despite the trailing underscore
        in the function name.
    """
    sin = sin.unsqueeze(unsqueeze_dim)
    cos = cos.unsqueeze(unsqueeze_dim)
    embed = (x * cos) + (rotate_half(x) * sin)
    return embed

def last_token_relevance(scores):
    """Return the last-axis sum of the entry at index -1 along dim 1.

    The last axis is summed by multiplying with a column of ones, which has the
    same dtype and device as ``scores``. The entry at position -1 of dim 1 is
    then selected and a trailing axis is appended. For a 3-D input of shape
    (B, S, K) the result has shape (B, 1, 1).
    """
    ones_col = scores.new_ones(scores.size(-1), 1)
    row_sums = scores @ ones_col
    return row_sums[:, -1].unsqueeze(-1)

def topn_mean(scores, n):
    """Average the ``n`` largest last-axis sums along dim 1.

    Each slice of the last axis is summed via a matmul with a column of ones.
    The ``n`` largest sums along dim 1 are then averaged. For a 3-D input of
    shape (B, S, K) the result has shape (B, 1, 1). ``torch.topk`` raises if
    ``n`` exceeds the size of dim 1.
    """
    row_sums = scores @ scores.new_ones(scores.size(-1), 1)
    best = row_sums.squeeze(-1).topk(n, dim=1).values
    return best.mean(dim=1, keepdim=True)[..., None]

def masked_topk(attn_mx, n, mask=None, mx_reduction='mean', topk_reduction='mean'):
    if mx_reduction == 'sum': attn_mx = attn_mx.sum(-1)
    else: attn_mx = attn_mx.mean(-1)
    #
    if mask is not None:
        if n == -1: attn_mx[mask] = 0.
        else: attn_mx[mask] = torch.finfo(torch.float32).min
    if n == -1: n = attn_mx.size(1)
    if topk_reduction == 'sum': topk_value = torch.topk(attn_mx, k=n, dim=1).values.sum(-1, keepdim=True)
    else: topk_value = torch.topk(attn_mx, k=n, dim=1).values.mean(-1, keepdim=True)
    return topk_value.unsqueeze(-1)

def now(format='%Y-%m-%d-%H:%M:%S'):
    """Return the current local time formatted with ``time.strftime``."""
    local_struct = time.localtime()
    return time.strftime(format, local_struct)

def flatten(data: list):
    """Concatenate the items of every sub-iterable in ``data`` into one list (one level deep)."""
    return [item for inner in data for item in inner]
