from typing import Optional
import torch
import torch.nn as nn
from fla.modules import ShortConvolution
from fla.modules.l2norm import l2_norm
from fla.modules.rotary import ChunkedLinear, rotary_temperature
from einops import rearrange


class SelectiveRoPENew(nn.Module):
    def __init__(
        self,
        head_dim: int,
        num_heads: int = 1,
        skip_conv: bool = False,
        dtype: Optional[torch.dtype] = None,
        d_conv: int = 4,
        temp_type: str = "rope",
        temp_theta: float = 500000,
        temp_max: float = 1.0,
        temp_grad: bool = False,
        is_softmax: bool = False,
        use_low_rank_phi_proj: bool = False,
        phi_proj_rank: int = 32,
        phi_conv_activation: str | None = None,
        partial_rope_ratio: float = 1.0,
        pi_proj: bool = False,
    ):
        super().__init__()
        assert partial_rope_ratio <= 1.0 and partial_rope_ratio > 0.0

        self.head_dim = head_dim
        self.num_heads = num_heads
        self.skip_conv = skip_conv
        self.is_softmax = is_softmax
        self.use_low_rank_phi_proj = use_low_rank_phi_proj
        self.phi_proj_rank = phi_proj_rank
        self.phi_conv_activation = phi_conv_activation
        self.partial_rope_ratio = partial_rope_ratio
        self.pi_proj = pi_proj
        # Project from model hidden size (num_heads * head_dim) to itself
        pe_dim = head_dim // 2
        self.phi_bias = nn.Parameter(
            torch.zeros(1, 1, num_heads, int(self.partial_rope_ratio * pe_dim)).float(),
            requires_grad=True,
        )

        if self.use_low_rank_phi_proj:
            self.phi_proj = nn.Sequential(
                nn.Linear(num_heads * head_dim, phi_proj_rank, bias=False),
                nn.Linear(phi_proj_rank, num_heads * pe_dim, bias=False),
            )
        else:
            self.phi_proj = ChunkedLinear(
                2 * pe_dim,
                int(self.partial_rope_ratio * pe_dim),
                num_heads=num_heads,
                bias=False,
                random_init=True,
                rank=-1,
            )

        if not skip_conv:
            self.phi_conv1d = ShortConvolution(
                hidden_size=int(self.partial_rope_ratio * num_heads * pe_dim),
                kernel_size=d_conv,
                bias=False,
                activation=phi_conv_activation,
                dtype=dtype,
            )

        self.temperature = nn.Parameter(
            rotary_temperature(
                temp_type, temp_theta, int(self.partial_rope_ratio * head_dim), temp_max
            ).reshape(1, 1, 1, -1),
            requires_grad=temp_grad,
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        inputs: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, ...]:
        if self.is_softmax:
            q_normed = l2_norm(q)
            # k_normed = l2_norm(k, dim=-1)

        if self.use_low_rank_phi_proj:
            if not self.pi_proj:
                phi = rearrange(
                    l2_norm(self.phi_proj(inputs)),
                    "b t (h d) -> b (h d) t",
                    h=self.num_heads,
                )
            else:
                weights_norm = (
                    self.phi_proj[0].weight.T @ self.phi_proj[1].weight.T
                ).norm(dim=0)
                weights_norm = rearrange(weights_norm, "(h d) -> d h", h=self.num_heads)
                # import pdb; pdb.set_trace()

                inputs = l2_norm(inputs)
                inputs = rearrange(inputs, "b t h d -> b t (d h)", h=self.num_heads)

                # import pdb; pdb.set_trace()
                phi = torch.pi * (
                    rearrange(
                        self.phi_proj(inputs), "b t (d h) -> b t d h", h=self.num_heads
                    )
                    / weights_norm
                )  # b x T x d x h / d x h
                phi = rearrange(phi, "b t d h -> b (h d) t", h=self.num_heads)
        else:
            phi = rearrange(
                self.phi_proj(
                    rearrange(
                        q_normed if self.is_softmax else q, "b t h d -> (b t) h d"
                    )
                ),
                "(b t) h d -> b (h d) t",
                b=q.shape[0],
            )

        # if self.pi_proj and self.use_low_rank_phi_proj:
        #     phi = self.phi_proj.weight

        if not self.skip_conv:
            # Apply ShortConvolution with proper reshaping
            # ShortConvolution expects [B, T, D] so reshape from [B, D, T]
            phi, _ = self.phi_conv1d(rearrange(phi, "b d t -> b t d"))
            # Reshape back to [B, T, H, N]

            phi = rearrange(
                phi,
                "b t (h d) -> b t h d",
                h=self.num_heads,
            )
        else:
            # Hard shift, with zero padding
            # Accepts b d t, and shifts along the last dim
            phi = phi - torch.cat(
                [torch.zeros_like(phi[..., :1]), phi[..., :-1]], dim=-1
            )
            phi = rearrange(phi, "b (h d) t -> b t h d", h=self.num_heads)

        phi = phi + torch.exp(self.phi_bias)
        # phi = phi
        # Cumulative sum along the sequence dimension
        phi_tilde = torch.cumsum(phi, dim=1)

        qk_phi_tilde = torch.cat([phi_tilde, phi_tilde], dim=2)

        if self.partial_rope_ratio < 1.0:
            q_l, q_r = q.split(
                [
                    int(self.partial_rope_ratio * self.head_dim),
                    self.head_dim - int(self.partial_rope_ratio * self.head_dim),
                ],
                dim=-1,
            )
            q = q_l
            k_l, k_r = k.split(
                [
                    int(self.partial_rope_ratio * self.head_dim),
                    self.head_dim - int(self.partial_rope_ratio * self.head_dim),
                ],
                dim=-1,
            )
            k = k_l

        qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float()
        rotated_qk = torch.stack(
            [
                qk_r2[..., 0] * torch.cos(self.temperature * qk_phi_tilde)
                - qk_r2[..., 1] * torch.sin(self.temperature * qk_phi_tilde),
                qk_r2[..., 1] * torch.cos(self.temperature * qk_phi_tilde)
                + qk_r2[..., 0] * torch.sin(self.temperature * qk_phi_tilde),
            ],
            -1,
        ).flatten(3)

        q, k = torch.split(rotated_qk.type_as(q), q.shape[2], dim=2)

        if self.partial_rope_ratio < 1.0:
            q = torch.cat([q, q_r], dim=-1)
            k = torch.cat([k, k_r], dim=-1)

        return q, k


if __name__ == "__main__":
    # Smoke test: one forward pass of the low-rank, pi-scaled variant on
    # random bfloat16 tensors on the GPU under autocast.
    batch, seq_len, n_heads, dim = 3, 50, 7, 64
    device = "cuda"

    config = dict(
        head_dim=dim,
        num_heads=n_heads,
        skip_conv=False,
        dtype=torch.bfloat16,
        d_conv=4,
        temp_type="rope",
        temp_theta=500000.0,
        temp_max=1.0,
        temp_grad=False,
        is_softmax=False,
        use_low_rank_phi_proj=True,
        phi_proj_rank=32,
        phi_conv_activation=None,
        partial_rope_ratio=1,
        pi_proj=True,
    )
    rope = SelectiveRoPENew(**config).to(device)

    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        # Drawn in the order inputs, queries, keys.
        x, q, k = [
            torch.randn(batch, seq_len, n_heads, dim).to(torch.bfloat16).to(device)
            for _ in range(3)
        ]
        rope(q, k, x)
