import torch
import torch.nn as nn
import math
from functools import lru_cache
from typing import Tuple

def sinusoidal_positional_embedding1d(d_model, max_seq_len=5000, theta=10000.0):
    """Build a fixed sinusoidal positional-embedding table.

    Column ``2i`` holds ``sin(pos * w_i)`` and column ``2i + 1`` holds
    ``cos(pos * w_i)``, where ``w_i = theta ** (-2i / d_model)``.

    Args:
        d_model: Width of the embedding (number of columns). Odd widths are
            supported; the last sine column then has no cosine partner.
        max_seq_len: Number of positions (rows) to generate.
        theta: Base of the geometric frequency progression.

    Returns:
        A tensor of shape ``(max_seq_len, d_model)``.
    """
    position = torch.arange(max_seq_len).unsqueeze(1)  # n x 1
    # One frequency per even column index: ceil(d / 2) entries.
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(theta) / d_model))
    angles = position * div_term  # n x ceil(d / 2)

    pe = torch.zeros(max_seq_len, d_model)  # n x d
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(d / 2) odd columns, so drop the trailing frequency
    # when d_model is odd; for even d_model the slice is a no-op.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    return pe  # n x d


def sinusoidal_positional_embedding2d(d_model, max_height=100, max_width=100, theta=10000.0):
    """Build a fixed 2D sinusoidal positional-embedding grid.

    The first ``d_model // 2`` channels encode the row (height) index and the
    next ``d_model // 2`` channels encode the column (width) index. Within each
    half, even channels hold sines and odd channels hold cosines of
    ``pos * theta ** (-2i / (d_model // 2))``.

    Args:
        d_model: Requested embedding width. Each axis receives
            ``d_model // 2`` channels, so the last dimension of the result is
            ``2 * (d_model // 2)``.
        max_height: Number of rows in the grid.
        max_width: Number of columns in the grid.
        theta: Base of the geometric frequency progression.

    Returns:
        A tensor of shape ``(max_height, max_width, 2 * (d_model // 2))``.
    """
    half_dim = d_model // 2
    # Both axes use the same frequencies: ceil(half_dim / 2) entries.
    div_term = torch.exp(torch.arange(0, half_dim, 2) * -(math.log(theta) / half_dim))

    def _axis_table(length):
        # Sin/cos table of shape (length, half_dim) for one spatial axis.
        angles = torch.arange(length).unsqueeze(1) * div_term
        table = torch.zeros(length, half_dim)
        table[:, 0::2] = torch.sin(angles)
        # Only floor(half_dim / 2) odd channels exist; trim the trailing
        # frequency so an odd half_dim does not cause a shape mismatch.
        table[:, 1::2] = torch.cos(angles[:, : half_dim // 2])
        return table

    pe_h = _axis_table(max_height).unsqueeze(1)  # h x 1 x d/2
    pe_w = _axis_table(max_width).unsqueeze(0)  # 1 x w x d/2

    # Tile each axis table across the other axis, then stack along channels.
    pe = torch.cat((pe_h.repeat(1, max_width, 1), pe_w.repeat(max_height, 1, 1)), dim=2)

    return pe  # h x w x d


@lru_cache
def precompute_freqs_cis(dim: int, end: int = 128_000, theta: float=10000.0) -> torch.Tensor:
    """Precompute unit-magnitude complex rotation factors.

    Entry ``[p, i]`` of the result is ``exp(1j * p * theta ** (-2i / dim))``.
    Results are memoized by ``lru_cache``, so repeated calls with the same
    arguments return the same tensor object.

    Args:
        dim: Channel width being rotated; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``(end, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    # Outer product gives the rotation angle for every (position, frequency) pair.
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by complex rotation factors.

    Consecutive pairs along the last axis of each input are treated as the
    (real, imaginary) parts of a complex number, multiplied by ``freqs_cis``,
    and unpacked back into the original layout.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; its last dimension must be even.
        freqs_cis: Complex rotation factors. A singleton axis is inserted
            before its last axis so it broadcasts over the second-to-last axis
            of the paired inputs (presumably the attention-head axis —
            confirm against the caller).

    Returns:
        ``(xq_rotated, xk_rotated)``, each with the shape and dtype of the
        corresponding input.
    """
    # Compute in float32 and view adjacent pairs as complex numbers.
    paired = [
        torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        for t in (xq, xk)
    ]
    rotation = freqs_cis[..., None, :]
    # Complex multiplication performs the rotation; flatten restores the pair layout.
    q_rotated, k_rotated = [
        torch.view_as_real(z * rotation).flatten(-2) for z in paired
    ]
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
