import torch
import torch.nn as nn
from causal_conv1d import causal_conv1d_fn
from einops import rearrange
from fla.modules.rotary import ChunkedLinear  # type: ignore
import numpy as np


def precompute_selective_rope_weights(
    dim: int, theta: float = 500000.0, epsilon: float = 0.995
):
    """Build the fixed RoPE inverse-frequency table.

    Entry ``i`` equals ``1 / theta ** (2 * i / dim)`` for ``i`` in
    ``0 .. ceil(dim / 2) - 1``, computed in float32.

    Args:
        dim: Rotary width; one frequency is produced per even index below it.
        theta: Base of the geometric frequency progression.
        epsilon: Accepted for API compatibility; currently unused.

    Returns:
        A float32 tensor of shape ``(1, 1, 1, ceil(dim / 2))``, shaped to
        broadcast against ``(batch, time, heads, pairs)`` phase tensors.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (theta ** exponents)
    return inv_freq.view(1, 1, 1, -1)


class OGSelectiveRoPE(nn.Module):
    """Selective (input-dependent) rotary position embedding for q and k.

    Instead of using the token index as the rotation phase, the phase is
    derived from ``q`` itself:

    1. a per-head ``ChunkedLinear`` projection maps each head's
       ``2 * pe_dim`` channels of ``q`` to ``pe_dim`` values ``phi``;
    2. unless ``skip_conv_cumsum`` is set, ``phi`` is filtered over time by a
       depthwise convolution (applied through ``causal_conv1d_fn``);
    3. a learned per-head bias ``phi_bias`` is added;
    4. unless ``skip_conv_cumsum`` is set, ``phi`` is cumulatively summed
       over time.

    The resulting phases are scaled by fixed inverse frequencies
    (``temperature``) and applied as 2-D rotations to adjacent channel pairs
    ``(x[2i], x[2i + 1])`` of both ``q`` and ``k``, using the same phases for
    both.

    Args:
        num_heads: Number of heads in ``q`` (and ``k``).
        d_conv: Kernel width of the depthwise time convolution.
        d_state: Head width; ``pe_dim = d_state // 2`` rotation pairs are used.
        skip_conv_cumsum: If True, skip both the convolution and the
            cumulative sum.
        device: Device for the convolution weights only.
        dtype: Dtype for the convolution weights only.
    """

    def __init__(
        self,
        num_heads: int,
        d_conv: int = 4,
        d_state: int = 128,
        skip_conv_cumsum: bool = False,
        device: torch.device | str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.skip_conv_cumsum = skip_conv_cumsum
        # Number of rotation pairs per head.
        pe_dim = d_state // 2
        # Learned per-head phase offset, created in float32 and initialised to ones.
        self.phi_bias = nn.Parameter(
            torch.ones(1, 1, self.num_heads, pe_dim).float(), requires_grad=True
        )

        # Per-head projection from the head width (2 * pe_dim) to pe_dim phases.
        self.phi_proj = ChunkedLinear(
            2 * pe_dim,
            pe_dim,
            num_heads=num_heads,
            bias=False,
            random_init=True,
            rank=-1,
        )

        if not skip_conv_cumsum:
            # Depthwise (groups == channels) temporal filter over the phases.
            # Only its weight is used in forward (passed to causal_conv1d_fn),
            # so the `padding` configured here never affects the output.
            # NOTE(review): only this layer receives `device`/`dtype`; the other
            # parameters are created on the default device, so the module
            # presumably must be moved with `.to(...)` before use — confirm.
            self.phi_conv1d = nn.Conv1d(
                in_channels=self.num_heads * pe_dim,
                out_channels=self.num_heads * pe_dim,
                bias=False,
                kernel_size=d_conv,
                groups=self.num_heads * pe_dim,
                padding=d_conv - 1,
                device=device,
                dtype=dtype,
            )

        # Fixed inverse frequencies of shape (1, 1, 1, pe_dim), kept as a
        # frozen Parameter (requires_grad=False) rather than a buffer.
        self.temperature = nn.Parameter(
            precompute_selective_rope_weights(2 * pe_dim),
            requires_grad=False,
        )

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Rotate ``q`` and ``k`` by phases computed from ``q``.

        Args:
            q: Queries laid out as ``(batch, time, heads, d)`` with
                ``d == 2 * pe_dim`` (the input width of ``phi_proj``).
            k: Keys; expected to have the same shape as ``q``, since its heads
                are rotated with the same phases as ``q``'s heads.

        Returns:
            ``(q_rotated, k_rotated)``, each cast back to ``q``'s dtype. The
            rotation itself is computed in float32.
        """
        coeff = 1.0

        # Per-head projection q -> phi, laid out as (b, h * pe_dim, t) for the
        # time convolution.
        phi = rearrange(
            self.phi_proj(rearrange(q, "b t h d -> (b t) h d")),
            "(b t) h d -> b (h d) t",
            b=q.shape[0],
        )
        if not self.skip_conv_cumsum:
            phi = causal_conv1d_fn(
                phi,
                rearrange(self.phi_conv1d.weight, "d 1 w -> d w"),
                activation=None,
            )
        phi = rearrange(phi, "b (h d) t -> b t h d", h=self.num_heads)

        # phi_bias is created in float32; while it stays float32 this addition
        # also promotes phi to float32 before the cumulative sum.
        phi = phi + self.phi_bias

        # Accumulate per-step phases over time (dim 1): a constant phi yields
        # phases that grow linearly with position, as in standard RoPE.
        phi_tilde = phi if self.skip_conv_cumsum else torch.cumsum(phi, dim=1)

        # q and k share phases: duplicate along the head dim to line up with
        # the head-wise concatenation of q and k below.
        qk_phi_tilde = torch.cat([phi_tilde, phi_tilde], dim=2)

        # (b, t, 2h, d) -> (b, t, 2h, pe_dim, 2): adjacent channel pairs.
        qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float()
        x_even = qk_r2[..., 0]
        x_odd = qk_r2[..., 1]

        # Build the phase tensor and its cos/sin once, rather than
        # recomputing them for every term of the rotation.
        angle = coeff * self.temperature * qk_phi_tilde
        cos_a = torch.cos(angle)
        sin_a = torch.sin(angle)
        rotated_qk = torch.stack(
            [
                x_even * cos_a - x_odd * sin_a,
                x_odd * cos_a + x_even * sin_a,
            ],
            dim=-1,
        ).flatten(3)

        return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2)
