import torch
import math
import numpy as np

# Inverse dim formula to find dim based on number of rotations
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the (fractional) dimension index matching a rotation count.

    Evaluates
    ``dim * ln(max_position_embeddings / (2 * pi * num_rotations)) / (2 * ln(base))``.

    Args:
        num_rotations: Rotation count to invert; must be positive for the
            logarithm to be defined.
        dim: Size of the rotary dimension.
        base: Base of the rotary frequency schedule.
        max_position_embeddings: Context length used in the ratio.

    Returns:
        A float; it is neither rounded nor clamped here.
    """
    length_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(length_ratio)
    denominator = 2 * math.log(base)
    return numerator / denominator

# Find dim range bounds based on rotations.
# Example call (with beta_fast=32 and beta_slow=1):
#   low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return integer ``(low, high)`` dimension bounds for two rotation counts.

    Each rotation count is converted with ``find_correction_dim``; the result
    for ``low_rot`` is rounded down and the result for ``high_rot`` is rounded
    up.

    Args:
        low_rot: Rotation count that yields the lower bound.
        high_rot: Rotation count that yields the upper bound.
        dim: Size of the rotary dimension.
        base: Base of the rotary frequency schedule.
        max_position_embeddings: Context length forwarded to
            ``find_correction_dim``.

    Returns:
        A ``(low, high)`` tuple where ``low`` is at least 0 and ``high`` is at
        most ``dim - 1``.
    """
    lower_bound = math.floor(
        find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    upper_bound = math.ceil(
        find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp both ends so they stay valid indices into a length-`dim` axis.
    lower_bound = max(lower_bound, 0)
    upper_bound = min(upper_bound, dim - 1)
    return lower_bound, upper_bound

def linear_ramp_mask(min, max, dim):
    """Build a length-``dim`` ramp that rises linearly from 0 to 1.

    Index ``i`` maps to ``(i - min) / (max - min)`` and the result is clamped
    to ``[0, 1]``.

    Args:
        min: Index at which the ramp is 0.
        max: Index at which the ramp reaches 1.
        dim: Number of elements in the returned tensor.

    Returns:
        A float32 tensor of shape ``(dim,)``.
    """
    # A zero-width ramp would divide by zero, so nudge the upper end.
    upper = max + 0.001 if min == max else max

    positions = torch.arange(dim, dtype=torch.float32)
    return ((positions - min) / (upper - min)).clamp(0, 1)

def linear_ramp_mask2(beta, alpha, base, dim, max_position_embeddings):
    """Build a ``[0, 1]`` ramp over ``dim`` slots from per-slot ``rounds`` values.

    For slot ``i`` (``0 <= i < dim``) the value
    ``rounds[i] = max_position_embeddings / (2 * pi * base ** (i / dim))``
    is computed and mapped linearly so that ``alpha`` maps to 0 and ``beta``
    maps to 1; the result is clamped to ``[0, 1]``.

    Args:
        beta: ``rounds`` value at (and above) which the ramp is 1.
        alpha: ``rounds`` value at (and below) which the ramp is 0.
        base: Base of the exponential schedule.
        dim: Number of slots, i.e. the length of the returned tensor.
        max_position_embeddings: Context length used to compute ``rounds``.

    Returns:
        A float32 tensor of shape ``(dim,)``.

    Side effects:
        Prints two diagnostic lines. ``low`` is the index just before the
        first entry that is not 1, and ``high`` is the index of the first
        entry equal to 0. Either is reported as ``None`` when no such entry
        exists, instead of raising ``UnboundLocalError``.
    """
    if beta == alpha:
        beta += 0.001  # Prevent singularity (same guard as linear_ramp_mask)

    l1 = max_position_embeddings / (2 * math.pi)
    dims_array = np.arange(dim) / dim

    rounds = torch.tensor([l1 / (base ** i) for i in dims_array]).float()

    linear_func = (rounds - alpha) / (beta - alpha)
    ramp_func = torch.clamp(linear_func, 0, 1)

    # Locate the ramp boundaries without assuming that they exist.
    not_saturated = torch.nonzero(ramp_func != 1).flatten()
    low = not_saturated[0].item() - 1 if not_saturated.numel() else None

    fully_zero = torch.nonzero(ramp_func == 0).flatten()
    high = fully_zero[0].item() if fully_zero.numel() else None

    print(f"low: {low}, high: {high}, dim: {dim}")
    print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
    return ramp_func

def get_mscale(scale=1):
    """Return the magnitude multiplier for a given scale factor.

    Args:
        scale: Scale factor.

    Returns:
        ``1.0`` when ``scale <= 1``, otherwise ``0.1 * ln(scale) + 1.0``.
    """
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

class LlamaYaRNScaledRotaryEmbedding(torch.nn.Module):
    """Rotary embedding whose inverse frequencies are blended by a ramp mask.

    On construction ``yarn()`` registers the ``inv_freq`` buffer and sets
    ``mscale``. ``forward`` turns position ids into ``(cos, sin)`` tables, each
    multiplied by ``attention_scaling`` (initialised from ``mscale``).

    Args:
        dim: Rotary dimension (original note: dim = k_div = q_div = v_div).
        max_position_embeddings: Stored on the instance and as
            ``max_seq_len_cached``.
        base: Base of the frequency schedule.
        scale: Divisor applied to the interpolated frequencies; also the input
            to ``get_mscale``.
        original_max_position_embeddings: Context length passed to
            ``linear_ramp_mask2``.
        extrapolation_factor: Multiplier applied to the ramp mask.
        attn_factor: Multiplier applied to ``get_mscale(scale)``.
        beta_fast: Passed to ``linear_ramp_mask2`` as ``beta``.
        beta_slow: Passed to ``linear_ramp_mask2`` as ``alpha``.
        finetuned: Accepted but not used by the code in this class.
        device: Device on which the frequency tensors are created.
    """

    # dim= k_div = q_div = v_div
    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        scale=1,
        original_max_position_embeddings=2048,
        extrapolation_factor=1,
        attn_factor=1,
        beta_fast=32,
        beta_slow=1,
        finetuned=False,
        device=None,
    ):
        super().__init__()

        # Geometry of the embedding.
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings

        # Scaling hyper-parameters read by yarn().
        self.scale = scale
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Registers `inv_freq` and sets `mscale`; both are read just below.
        self.yarn(device)

        self.max_seq_len_cached = max_position_embeddings
        self.original_inv_freq = self.inv_freq
        self.attention_scaling = self.mscale

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` tables for ``position_ids``, cast to ``x.dtype``.

        ``position_ids`` is indexed as a 2-D ``(batch, seq)`` tensor. Each
        output has shape ``(batch, seq, 2 * len(inv_freq))``; ``x`` is used
        only for its device and dtype.
        """
        batch_size = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()

        # autocast is not entered with "mps"; fall back to "cpu" in that case.
        if isinstance(x.device.type, str) and x.device.type != "mps":
            device_type = x.device.type
        else:
            device_type = "cpu"

        # Autocast is disabled so the angles are computed in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            angles = torch.matmul(freq_column.float(), position_row.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = torch.cos(emb) * self.attention_scaling
            sin = torch.sin(emb) * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def yarn(self, device):
        """Compute blended inverse frequencies and the magnitude scale.

        Registers the ``inv_freq`` buffer as
        ``interpolation * (1 - mask) + extrapolation * mask`` and sets
        ``self.mscale`` to ``get_mscale(scale) * attn_factor``.
        """
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        pos_freqs = self.base ** exponents

        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (self.scale * pos_freqs)
        print(">>>>>>>>using yarn() method<<<<<<<<<<")

        ramp = linear_ramp_mask2(
            self.beta_fast,
            self.beta_slow,
            self.base,
            self.dim // 2,
            self.original_max_position_embeddings,
        )
        blend = ramp.float().to(device) * self.extrapolation_factor

        # blend == 1 keeps the unscaled frequency; blend == 0 uses the one divided by `scale`.
        inv_freq = interpolated * (1 - blend) + extrapolated * blend
        self.register_buffer("inv_freq", inv_freq)

        # Get n-d magnitude scaling.
        self.mscale = float(get_mscale(self.scale) * self.attn_factor)