import torch
import math

def rotary_embedding(dim, max_seq_len, base=10000):
    """
    Generate the sine/cosine angle tables used by rotary positional encoding.

    :param dim: Embedding dimension. One frequency is produced for every even
        index in ``range(dim)``, i.e. ``(dim + 1) // 2`` frequencies.
    :param max_seq_len: Number of positions (table rows) to generate.
    :param base: Base of the geometric frequency progression
        ``inv_freq[i] = base ** (-2i / dim)`` (default 10000).
    :return: Tuple ``(sin, cos)``, each a float32 tensor of shape
        ``(max_seq_len, (dim + 1) // 2)`` holding ``sin``/``cos`` of
        ``position * inv_freq[i]``. These are angle tables, not rotation
        matrices; ``apply_rope`` combines them with the input.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    # Outer product: angle[p, i] = positions[p] * inv_freq[i]
    sinusoid_inp = torch.einsum('i,j->ij', positions, inv_freq)
    return sinusoid_inp.sin(), sinusoid_inp.cos()

def rotary_embedding_ntk(dim, max_seq_len, scale=8, base=10000):
    """
    Generate RoPE sine/cosine tables with NTK-aware base scaling.

    Rather than compressing positions, the frequency base is enlarged to
    ``base * scale ** (dim / (dim - 2))``. For even ``dim`` this keeps the
    highest frequency (index 0) unchanged and divides the lowest frequency by
    ``scale``. The table covers ``scale`` times as many positions.

    :param dim: Embedding dimension (``(dim + 1) // 2`` frequencies).
    :param max_seq_len: Original (unscaled) sequence length; the table has
        ``max_seq_len * scale`` rows.
    :param scale: Context-extension factor.
    :param base: Frequency base before the NTK adjustment (default 10000).
    :return: Tuple ``(sin, cos)``, each of shape
        ``(max_seq_len * scale, (dim + 1) // 2)``.
    """
    num_positions = max_seq_len * scale
    # With dim <= 2 the only frequency is base**0 == 1, which the NTK rule
    # leaves untouched anyway. Skipping the adjustment also avoids dividing
    # by zero in dim / (dim - 2) when dim == 2.
    if dim > 2:
        base = base * scale ** (dim / (dim - 2))  # NTK base-change formula

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(num_positions).float()
    # Outer product: angle[p, i] = positions[p] * inv_freq[i]
    sinusoid_inp = torch.einsum('i,j->ij', positions, inv_freq)
    return sinusoid_inp.sin(), sinusoid_inp.cos()

def rotary_embedding_linear(dim, max_seq_len, scale=8, base=10000):
    """
    Generate RoPE sine/cosine tables with linear position interpolation.

    The table covers ``max_seq_len * scale`` positions. Every position index
    is divided by ``scale``, so the extended range is squeezed into the
    angle range of the original ``max_seq_len`` positions.

    :param dim: Embedding dimension (``(dim + 1) // 2`` frequencies).
    :param max_seq_len: Original (unscaled) sequence length; the table has
        ``max_seq_len * scale`` rows.
    :param scale: Context-extension / interpolation factor.
    :param base: Base of the frequency progression (default 10000).
    :return: Tuple ``(sin, cos)``, each of shape
        ``(max_seq_len * scale, (dim + 1) // 2)``.
    """
    num_positions = max_seq_len * scale

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Interpolate: position p is treated as p / scale
    positions = torch.arange(num_positions).float() / scale
    # Outer product: angle[p, i] = positions[p] * inv_freq[i]
    sinusoid_inp = torch.einsum('i,j->ij', positions, inv_freq)
    return sinusoid_inp.sin(), sinusoid_inp.cos()

def apply_rope(x, sin, cos):
    """
    Rotate each interleaved (even, odd) feature pair of ``x`` by RoPE angles.

    :param x: Input tensor whose last axis holds the features, e.g.
        ``(seq_len, dim)`` or ``(batch_size, seq_len, dim)``. The feature size
        is expected to be even; with an odd size the even/odd slices differ in
        length and the arithmetic below cannot broadcast.
    :param sin: Sine table broadcastable against ``x[..., ::2]``, e.g. a
        ``(seq_len, dim // 2)`` slice of the tables from ``rotary_embedding``.
    :param cos: Cosine table with the same shape as ``sin``.
    :return: New tensor with the same shape and dtype as ``x`` (pair results
        are cast into a ``zeros_like(x)`` buffer).
    """
    rotated = torch.zeros_like(x)
    even, odd = x[..., 0::2], x[..., 1::2]
    # 2-D rotation of every pair: [cos -sin; sin cos] @ [even; odd]
    rotated[..., 0::2] = even * cos - odd * sin
    rotated[..., 1::2] = even * sin + odd * cos
    return rotated

# Example Usage: compare plain RoPE with the linear-interpolation and
# NTK-scaled variants on the same random input. Each scaled variant starts
# from a 10-position length stretched by scale=8 to cover all 80 positions.
seq_len, dim = 80, 50  # Sample dimensions
x = torch.randn(seq_len, dim)  # Input embeddings (unseeded: values differ per run)

# Generate plain ROPE covering the full sequence
max_seq_len = 80
sin, cos = rotary_embedding(dim, max_seq_len)
x_rope = apply_rope(x, sin[:seq_len, :], cos[:seq_len, :])

# Generate ROPE with linear interpolation (10 * scale 8 = 80 table rows)
max_seq_len = 10
sin, cos = rotary_embedding_linear(dim, max_seq_len, scale=8)
x_rope_linear = apply_rope(x, sin[:seq_len, :], cos[:seq_len, :])

# Generate ROPE with NTK-aware base scaling (10 * scale 8 = 80 table rows)
max_seq_len = 10
sin, cos = rotary_embedding_ntk(dim, max_seq_len, scale=8)
x_rope_ntk = apply_rope(x, sin[:seq_len, :], cos[:seq_len, :])

print("Input:", x)
print("ROPE-normal:", x_rope)
print("ROPE-linear:", x_rope_linear)
print("ROPE-ntk:", x_rope_ntk)

# dim=0 reduces over the sequence axis, giving one similarity per channel
cosim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
output = cosim(x_rope, x_rope_linear)
print("cosin smilarity per channel of ROPE-normal vs linear:", output)

output = cosim(x_rope, x_rope_ntk)
print("cosin smilarity per channel of ROPE-normal vs ntk:", output)

output = cosim(x_rope_linear, x_rope_ntk)
print("cosin smilarity per channel of ROPE-linear vs ntk:", output)