import json
import torch
import torch.nn.functional as F

def read_text_from_file(file_path):
    """
    Read a UTF-8 text file and return its lines with surrounding whitespace removed.

    Every line is kept, including ones that become empty after stripping.

    Args:
        file_path (str): Path to the UTF-8 encoded text file.

    Returns:
        (List[str]): One entry per line of the file, each passed through ``str.strip()``.
    """
    with open(file_path, encoding="utf-8") as handle:
        # Iterating the handle yields the same lines readlines() would.
        return [raw.strip() for raw in handle]


def read_jsonl_file(file_path):
    """
    Load a JSON Lines file.

    Each non-blank line is parsed independently with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of causing a ``json.JSONDecodeError``.

    Args:
        file_path (str): Path to a UTF-8 encoded JSON Lines file.

    Returns:
        (List[Any]): Parsed JSON values (dicts, lists, strings, numbers, ...),
            one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            # json.loads("") raises, so blank separator/trailing lines are skipped.
            if not stripped:
                continue
            data.append(json.loads(stripped))
    return data


def rotate_half(x):
    """
    Swap and negate the two halves of the last dimension.

    The last dimension is split at ``size // 2`` into ``(first, second)`` and
    the result is ``cat(-second, first)`` along that dimension. When the size
    is odd, ``second`` holds the extra element.
    """
    split_at = x.shape[-1] // 2
    first, second = x[..., :split_at], x[..., split_at:]
    return torch.cat((second.neg(), first), dim=-1)


def apply_rotary_pos_emb(k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to a single (key) tensor.

    Computes ``k * cos + rotate_half(k) * sin`` after unsqueezing ``cos`` and
    ``sin`` at ``unsqueeze_dim`` so they broadcast against ``k``. Unlike the
    common (q, k) variant, this function rotates only one tensor.

    Args:
        k (`torch.Tensor`): The key tensor to rotate.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which ``cos`` and ``sin`` are unsqueezed so that
            they broadcast against ``k``. For example, if ``cos`` and ``sin``
            have shape [batch_size, seq_len, head_dim] and ``k`` has shape
            [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1. If
            ``k`` has shape [batch_size, seq_len, heads, head_dim], use
            unsqueeze_dim=2.

    Returns:
        `torch.Tensor`: The key tensor rotated using the Rotary Position
        Embedding. Its shape is the broadcast of ``k`` with the unsqueezed
        ``cos``/``sin``, normally equal to ``k.shape``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return k_embed

def remove_rotary_pos_emb(k_embed, cos, sin, unsqueeze_dim=1):
    """
    Undo a rotary position embedding on a key tensor.

    Computes ``k_embed * cos - rotate_half(k_embed) * sin``, i.e. the same
    rotation as ``apply_rotary_pos_emb`` but with the sine term negated
    (rotation by the opposite angle).

    NOTE(review): this is an exact inverse only when ``cos``/``sin`` are the
    same values used for the forward rotation, with identical first and second
    halves along the last dim (the usual ``cat((freqs, freqs))`` layout) —
    confirm with the caller.

    Args:
        k_embed (torch.Tensor): Key tensor that already has the rotation applied.
        cos (torch.Tensor): The cosine part of the rotary embedding.
        sin (torch.Tensor): The sine part of the rotary embedding.
        unsqueeze_dim (int, optional): Axis at which cos and sin are unsqueezed
            so they broadcast against ``k_embed``. Defaults to 1.

    Returns:
        torch.Tensor: The key tensor with the positional rotation removed.
    """
    # Insert the broadcast axis (e.g. heads) into cos/sin.
    cos_b = torch.unsqueeze(cos, unsqueeze_dim)
    sin_b = torch.unsqueeze(sin, unsqueeze_dim)
    rotated = rotate_half(k_embed)
    return (k_embed * cos_b) - (rotated * sin_b)

def concate_past_key_value(past_key_values_list: list, dim: int = 2):
    """
    Concatenate several per-layer key/value caches into a single cache.

    Each element of ``past_key_values_list`` is indexed as
    ``past[layer_idx][0]`` (keys) and ``past[layer_idx][1]`` (values). For every
    layer, the keys (and separately the values) of all caches are joined with
    ``torch.cat`` along ``dim``, preserving list order.

    NOTE(review): the default ``dim=2`` presumably is the sequence axis for
    tensors laid out as [batch_size, heads, seq_len, head_dim] — confirm
    against the model producing the caches.

    Args:
        past_key_values_list (list): Non-empty list of caches. Every cache
            must have the same number of layers.
        dim (int, optional): Axis along which keys/values are concatenated.
            Defaults to 2.

    Returns:
        (tuple): One ``(keys, values)`` tuple per layer.

    Raises:
        ValueError: If ``past_key_values_list`` is empty or the caches have
            differing numbers of layers.
    """
    if not past_key_values_list:
        raise ValueError("past_key_values_list must contain at least one cache")

    num_layers = len(past_key_values_list[0])
    # Mismatched layer counts would otherwise either be silently truncated to the
    # first cache's depth or fail with an opaque IndexError mid-loop.
    for cache_idx, past in enumerate(past_key_values_list):
        if len(past) != num_layers:
            raise ValueError(
                f"All caches must have the same number of layers: cache 0 has "
                f"{num_layers}, cache {cache_idx} has {len(past)}"
            )

    concatenated_past_key_values = []
    for layer_idx in range(num_layers):
        layer_keys = [past[layer_idx][0] for past in past_key_values_list]
        layer_values = [past[layer_idx][1] for past in past_key_values_list]
        concatenated_past_key_values.append(
            (torch.cat(layer_keys, dim=dim), torch.cat(layer_values, dim=dim))
        )
    return tuple(concatenated_past_key_values)
