from typing import final, override
import torch
import torch.nn as nn
import torch.nn.functional as F
from atom.training.config_options import PositionalEncodingType
from atom.atom.positional_encodings import TemporalRoPE, RoPE


@final
class QuadraticHeterogenousCrossAttention(nn.Module):
    def __init__(
        self,
        lifting_dim: int,
        num_heads: int,
        num_timesteps: int,
        positional_encoding: PositionalEncodingType,
        rope_base: float,
        rope_tau: float = 1000.0,
        attention_dropout: float = 0.2,
    ) -> None:
        """
        Heterogenous graph cross attention.

        Constructs separate K/V projections for each heterogeneous feature,
        then performs cross attention on queries generated from the q_data ("trunk").

        RoPE is optional; if use_rope=True, it is applied to Q and K.

        Parameters
        ----------
        num_hetero_feats : int
            Number of heterogeneous features.
        lifting_dim : int
            Dimension for Q, K, V.
        num_heads : int
            Number of attention heads.
        num_timesteps : int
            Number of timesteps, used for RoPE and spherical harmonics.
        use_rope : bool
            If True, apply RoPE to Q and K.
        rope_base : float
            Base for RoPE calculations.
        attention_dropout : float, optional
            Dropout rate for attention weights, by default 0.2.

        Attributes
        ----------
        key : nn.Linear
            Linear layer for key projection.
        value : nn.Linear
            Linear layer for value projection.
        query : nn.Linear
            Linear layer for query projection.
        out_proj : nn.Linear
            Linear layer for output projection.
        attention_denom : torch.Tensor
            Attention denominator.
        feature_weights : nn.Parameter
            Learnable weights for gating heterogeneous features.
        rope : TemporalRoPE, optional
            T-RoPE module.

        Raises
        ------
        AssertionError
            If `d_head` (lifting_dim / num_heads) is not even.
        """
        super().__init__()

        self.num_heads = num_heads
        self.lifting_dim = lifting_dim
        self.num_timesteps = num_timesteps
        self.rope_base = rope_base
        self.rope_tau = rope_tau
        self.d_head = self.lifting_dim // self.num_heads

        assert self.d_head % 2 == 0, "d_head must be even"

        self.key = nn.Linear(lifting_dim, lifting_dim)
        self.value = nn.Linear(lifting_dim, lifting_dim)
        self.query = nn.Linear(lifting_dim, lifting_dim)
        self.out_proj = nn.Linear(lifting_dim, lifting_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)
        # Fixed attention denominator sqrt(d_head)
        self.sqrt_dhead: float = float(self.d_head) ** 0.5

        self.feature_weights = nn.Parameter(torch.randn(3) * 0.1)

        self.positional_encoding_type = positional_encoding
        self.positional_encoding: nn.Module | None = None
        match positional_encoding:
            case PositionalEncodingType.TROPE:
                self.positional_encoding = TemporalRoPE(num_timesteps=self.num_timesteps, d_head=self.d_head, n_heads=self.num_heads, base=self.rope_base, tau=self.rope_tau)
            case PositionalEncodingType.ROPE:
                self.positional_encoding = RoPE(d_head=self.d_head, n_heads=self.num_heads, base=self.rope_base, learnable_offset=False)
            case PositionalEncodingType.SINUSOIDAL:
                self.positional_encoding = None
            case PositionalEncodingType.NONE:
                self.positional_encoding = None
            case _:
                raise ValueError(f"Invalid positional encoding type: {positional_encoding}")

    @override
    def forward(
        self,
        x_0: torch.Tensor,
        v_0: torch.Tensor | None,
        concatenated_features: torch.Tensor | None,
        q_data: torch.Tensor,
        mask: torch.Tensor | None,
        time_increments: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Performs heterogeneous cross-attention with multiple feature types.

        Parameters
        ----------
        x_0 : torch.Tensor
            Position features of shape `[B, T, N, d]`.
        v_0 : torch.Tensor | None
            Velocity features of shape `[B, T, N, d]` or None.
        concatenated_features : torch.Tensor | None
            Additional features of shape `[B, T, N, d]` or None.
        q_data : torch.Tensor
            Query data of shape `[B, T, N, d]`.
        mask : torch.Tensor | None, optional
            Mask of shape `[B, T, N, 1]` for padding, by default None.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `[B, T, N, d]`.

        Notes
        -----
        Process:
            1. Flatten query data from `[B, T, N, d]` to `[B, N * T (seq_q), d]`.
            2. Project query to `[B, heads, T*N, d_head]`.
            3. For each heterogeneous feature (x_0, v_0, concatenated_features):
               - Project to K/V of shape `[B, heads, T*N, d_head]`.
               - Apply RoPE if enabled.
               - Compute attention scores `Q·K^T / attention_denom`.
               - Compute attention weights and multiply by V.
               - Gate and accumulate to output.
            4. Reshape output to `[B, T, N, d]`.
        """
        # Flatten Q data: [B, T, N, d] -> [B, N * T (seq_q), d]
        B, T, N, d = q_data.shape
        q_data_flat = q_data.view(B, T * N, d)

        key_mask_for_scores: torch.Tensor | None = None
        rope_mask_for_rope: torch.Tensor | None = None
        if mask is not None:
            # Mask in shape: [B, T, N, 1]; need to mask attention of shape [B, heads, T*N, T*N]
            assert mask.shape == (B, T, N, 1), f"Expected mask shape (B,T,N,1) but got {mask.shape}"
            reshaped_mask = mask.reshape(B, T * N)
            key_mask_for_scores = reshaped_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T*N] for attention scores
            rope_mask_for_rope = reshaped_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, T*N, 1] for RoPE

        # No additive sinusoidal PE in attention

        # Project Q => [B, num_heads, N*T, d_head]
        q_proj: torch.Tensor = self.query(q_data_flat).view(B, T * N, self.num_heads, self.d_head).permute(0, 2, 1, 3)  # [B, heads, seq_q, d_head]

        # Apply RoPE-like PE after projection if configured
        if self.positional_encoding_type == PositionalEncodingType.TROPE and self.positional_encoding is not None:
            q_proj = self.positional_encoding(q_proj, rope_mask_for_rope, time_increments)
        elif self.positional_encoding_type == PositionalEncodingType.ROPE and self.positional_encoding is not None:
            q_proj = self.positional_encoding(q_proj, rope_mask_for_rope)

        # We'll accumulate over multiple heterogeneous features
        accumulated_out = torch.zeros_like(q_proj)

        # Collect the features of shape [B, N*T, d]
        hetero_features: list[torch.Tensor | None] = [
            x_0.view(B, T * N, d) if x_0 is not None else None,
            v_0.view(B, T * N, d) if v_0 is not None else None,
            concatenated_features.view(B, T * N, d) if concatenated_features is not None else None,
        ]

        gates = F.softmax(self.feature_weights, dim=0)  # Precompute gates; ∑ gates = 1
        for i, h_feat_flat in enumerate(hetero_features):
            if h_feat_flat is None:
                continue

            # No additive sinusoidal PE in attention

            # h_feat_flat.shape should be [B, T*N, d]
            assert h_feat_flat.shape == (B, T * N, self.lifting_dim), f"Expected shape (B, T*N, d) as {B, T * N, self.lifting_dim} but got {h_feat_flat.shape}"

            # Project K and V => [B, heads, seq_k, d_head]
            k_proj_i: torch.Tensor = self.key(h_feat_flat).view(B, N * T, self.num_heads, self.d_head).permute(0, 2, 1, 3)
            v_proj_i: torch.Tensor = self.value(h_feat_flat).view(B, N * T, self.num_heads, self.d_head).permute(0, 2, 1, 3)

            # Apply RoPE-like PE after projection if configured
            if self.positional_encoding_type == PositionalEncodingType.TROPE and self.positional_encoding is not None:
                k_proj_i = self.positional_encoding(k_proj_i, rope_mask_for_rope, time_increments)
            elif self.positional_encoding_type == PositionalEncodingType.ROPE and self.positional_encoding is not None:
                k_proj_i = self.positional_encoding(k_proj_i, rope_mask_for_rope)

            # 1) scores = Q·K^T / sqrt(d_head)
            scores = (q_proj @ k_proj_i.transpose(-2, -1)) / self.sqrt_dhead
            if key_mask_for_scores is not None:
                pass
                # scores shape is [B, heads, seq_q, seq_k] = [B, heads, T*N, T*N]
                scores = scores.masked_fill(key_mask_for_scores == 0, float("-inf"))

            # 2) softmax over seq_k dimension (dim=-1)
            attn_weights: torch.Tensor = self.attention_dropout(F.softmax(scores, dim=-1))
            # feat_i_out = F.scaled_dot_product_attention(q_proj, k_proj_i, v_proj_i, attn_mask=key_mask_for_scores, dropout_p=0.2, is_causal=False)
            # 3) multiply by V
            feat_i_out = attn_weights @ v_proj_i

            # Gate
            accumulated_out = accumulated_out + gates[i] * feat_i_out

        permuted_accumulated_out = accumulated_out.permute(0, 2, 1, 3).reshape(B, T * N, self.lifting_dim)
        final_out_projection: torch.Tensor = self.out_proj(permuted_accumulated_out)
        assert final_out_projection.shape == (B, T * N, self.lifting_dim), f"Expected (B, T*N, d) as {B, T * N, self.lifting_dim} but got {final_out_projection.shape}"
        # Unflatten => [B, T, N, d]
        final_out_reshaped = final_out_projection.view(B, T, N, self.lifting_dim)

        return final_out_reshaped


@final
class LinearHeterogenousCrossAttention(nn.Module):
    def __init__(
        self,
        lifting_dim: int,
        num_heads: int,
        num_timesteps: int,
        positional_encoding: PositionalEncodingType,
        rope_base: float,
        rope_tau: float,
        attention_dropout: float = 0.2,
    ) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.lifting_dim = lifting_dim
        self.num_timesteps = num_timesteps
        self.rope_base = rope_base
        self.d_head = self.lifting_dim // self.num_heads

        assert self.d_head % 2 == 0, "d_head must be even"

        self.key = nn.Linear(lifting_dim, lifting_dim)
        self.value = nn.Linear(lifting_dim, lifting_dim)
        self.query = nn.Linear(lifting_dim, lifting_dim)
        self.out_proj = nn.Linear(lifting_dim, lifting_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)
        self.sqrt_dhead: float = float(self.d_head) ** 0.5

        self.feature_weights = nn.Parameter(torch.randn(3) * 0.1)

        self.positional_encoding_type = positional_encoding
        self.positional_encoding: nn.Module | None = None
        match positional_encoding:
            case PositionalEncodingType.TROPE:
                self.positional_encoding = TemporalRoPE(num_timesteps=self.num_timesteps, d_head=self.d_head, n_heads=self.num_heads, base=self.rope_base, tau=rope_tau)
            case PositionalEncodingType.ROPE:
                self.positional_encoding = RoPE(d_head=self.d_head, n_heads=self.num_heads, base=self.rope_base, learnable_offset=False)
            case PositionalEncodingType.SINUSOIDAL:
                self.positional_encoding = None
            case PositionalEncodingType.NONE:
                self.positional_encoding = None
            case _:
                raise ValueError(f"Invalid positional encoding type: {positional_encoding}")

    @override
    def forward(
        self,
        x_0: torch.Tensor,
        v_0: torch.Tensor | None,
        concatenated_features: torch.Tensor | None,
        q_data: torch.Tensor,
        mask: torch.Tensor | None,
        time_increments: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Flatten Q data: [B, T, N, d] -> [B, N*T, d]
        B, T, N, d = q_data.shape
        q_data_flat = q_data.view(B, T * N, d)

        rope_mask_for_rope: torch.Tensor | None = None
        if mask is not None:
            assert mask.shape == (B, T, N, 1), f"Expected mask shape (B,T,N,1) but got {mask.shape}"
            reshaped_mask = mask.reshape(B, T * N)
            rope_mask_for_rope = reshaped_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, T*N, 1]
        else:
            reshaped_mask = None  # type: ignore

        # No additive sinusoidal PE in attention

        # Project Q => [B, heads, seq_q, d_head]
        q_proj: torch.Tensor = self.query(q_data_flat).view(B, T * N, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        # Apply RoPE-like PE after projection if configured
        if self.positional_encoding_type == PositionalEncodingType.TROPE and self.positional_encoding is not None:
            q_proj = self.positional_encoding(q_proj, rope_mask_for_rope, time_increments)
        elif self.positional_encoding_type == PositionalEncodingType.ROPE and self.positional_encoding is not None:
            q_proj = self.positional_encoding(q_proj, rope_mask_for_rope)

        # Linear attention uses softmax over feature dim for q and k
        q_lin = F.softmax(q_proj, dim=-1)
        if reshaped_mask is not None:
            query_mask_expand = reshaped_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, seq, 1]
            q_lin = q_lin * query_mask_expand

        accumulated_out = torch.zeros_like(q_lin)

        hetero_features: list[torch.Tensor | None] = [
            x_0.view(B, T * N, d) if x_0 is not None else None,
            v_0.view(B, T * N, d) if v_0 is not None else None,
            concatenated_features.view(B, T * N, d) if concatenated_features is not None else None,
        ]

        gates = F.softmax(self.feature_weights, dim=0)
        for i, h_feat_flat in enumerate(hetero_features):
            if h_feat_flat is None:
                continue

            # No additive sinusoidal PE in attention

            assert h_feat_flat.shape == (B, T * N, self.lifting_dim), f"Expected shape (B, T*N, d) as {B, T * N, self.lifting_dim} but got {h_feat_flat.shape}"

            k_proj_i: torch.Tensor = self.key(h_feat_flat).view(B, N * T, self.num_heads, self.d_head).permute(0, 2, 1, 3)
            v_proj_i: torch.Tensor = self.value(h_feat_flat).view(B, N * T, self.num_heads, self.d_head).permute(0, 2, 1, 3)

            if self.positional_encoding_type == PositionalEncodingType.TROPE and self.positional_encoding is not None:
                k_proj_i = self.positional_encoding(k_proj_i, rope_mask_for_rope, time_increments)
            elif self.positional_encoding_type == PositionalEncodingType.ROPE and self.positional_encoding is not None:
                k_proj_i = self.positional_encoding(k_proj_i, rope_mask_for_rope)

            k_lin = F.softmax(k_proj_i, dim=-1)
            if reshaped_mask is not None:
                key_mask_expand = reshaped_mask.unsqueeze(1).unsqueeze(-1)  # [B,1,seq,1]
                k_lin = k_lin * key_mask_expand
                v_proj_i = v_proj_i * key_mask_expand

            # Compute normalizer D_inv as in provided linear attention: 1 / sum_j q * sum_t k
            k_cumsum = k_lin.sum(dim=-2, keepdim=True)  # sum over sequence length
            eps: float = 1e-6
            D_inv = 1.0 / ((q_lin * k_cumsum).sum(dim=-1, keepdim=True) + eps)

            # q @ (k^T @ v) with dropout on the implicit attention weights via dropout on q
            out_i = (q_lin @ (k_lin.transpose(-2, -1) @ v_proj_i)) * D_inv
            out_i = self.attention_dropout(out_i)

            accumulated_out = accumulated_out + gates[i] * out_i

        permuted_accumulated_out = accumulated_out.permute(0, 2, 1, 3).reshape(B, T * N, self.lifting_dim)
        final_out_projection: torch.Tensor = self.out_proj(permuted_accumulated_out)
        assert final_out_projection.shape == (B, T * N, self.lifting_dim), f"Expected (B, T*N, d) as {B, T * N, self.lifting_dim} but got {final_out_projection.shape}"
        final_out_reshaped = final_out_projection.view(B, T, N, self.lifting_dim)
        return final_out_reshaped


def get_lifting_dim_irreps(lifting_dim: int) -> str:
    """Split ``lifting_dim`` channels into an irreps string of vectors plus scalars.

    As many ``1o`` terms as fit are used (each accounts for 3 channels, hence the
    division by 3); the leftover ``lifting_dim % 3`` channels become ``0e``
    terms. A zero multiplicity is still written out (e.g. ``0x0e``).

    Examples
    --------
    >>> get_lifting_dim_irreps(128)
    '42x1o + 2x0e'
    >>> get_lifting_dim_irreps(9)
    '3x1o + 0x0e'
    """
    n_vectors, n_scalars = divmod(lifting_dim, 3)
    return f"{n_vectors}x1o + {n_scalars}x0e"


@final
class QuadraticSelfAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        num_timesteps: int,
        lifting_dim: int,
        positional_encoding: PositionalEncodingType,
        rope_base: float,
        rope_tau: float = 1000.0,
        attention_dropout: float = 0.2,
    ) -> None:
        """
        Quadratic self-attention mechanism.

        Parameters
        ----------
        num_heads : int
            Number of attention heads.
        num_timesteps : int
            Number of timesteps, used for RoPE and spherical harmonics.
        lifting_dim : int
            Dimension for Q, K, V.
        use_rope : bool
            If True, apply RoPE to Q and K.
        attention_dropout : float, optional
            Dropout rate for attention weights, by default 0.2.

        Attributes
        ----------
        kv_projs : nn.Linear
            Linear layer for combined key and value projections.
        query : nn.Linear
            Linear layer for query projection.
        out_proj : nn.Linear
            Linear layer for output projection.
        attention_denom : torch.Tensor
            Attention denominator.
        rope : TemporalRoPEWithOffset, optional
            RoPE module.

        Raises
        ------
        AssertionError
            If `d_head` (lifting_dim / num_heads) is not even.
        """
        super().__init__()
        self.num_heads = num_heads
        self.lifting_dim = lifting_dim
        self.num_timesteps = num_timesteps
        self.d_head = self.lifting_dim // self.num_heads

        assert self.d_head % 2 == 0, "d_head must be even"

        self.kv_projs = nn.Linear(lifting_dim, 2 * lifting_dim)
        self.query = nn.Linear(lifting_dim, lifting_dim)
        self.out_proj = nn.Linear(lifting_dim, lifting_dim)
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.sqrt_dhead: float = float(self.d_head) ** 0.5

        self.positional_encoding_type = positional_encoding
        self.positional_encoding: nn.Module | None = None
        self.rope_base = rope_base
        self.rope_tau = rope_tau
        match positional_encoding:
            case PositionalEncodingType.TROPE:
                self.positional_encoding = TemporalRoPE(num_timesteps=self.num_timesteps, d_head=self.d_head, n_heads=self.num_heads, base=self.rope_base, tau=self.rope_tau)
            case PositionalEncodingType.ROPE:
                self.positional_encoding = RoPE(d_head=self.d_head, n_heads=self.num_heads, base=self.rope_base)
            case PositionalEncodingType.SINUSOIDAL:
                self.positional_encoding = None
            case PositionalEncodingType.NONE:
                self.positional_encoding = None
            case _:
                raise ValueError(f"Invalid positional encoding type: {positional_encoding}")

    @override
    def forward(self, tensor: torch.Tensor, mask: torch.Tensor | None, time_increments: torch.Tensor | None = None) -> torch.Tensor:
        """Performs self-attention on an input tensor.

        Parameters
        ----------
        tensor : torch.Tensor
            Input tensor of shape `[B, T, N, d]`.
            - `B` = batch size
            - `T` = number of timesteps
            - `N` = number of nodes
            - `d` = feature dimension
        mask : torch.Tensor | None, optional
            Mask of shape `[B, T, N, 1]` to mask attention scores, by default None.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `[B, T, N, d]`.

        Notes
        -----
        Process:
            1. Flatten input from `[B, T, N, d]` to `[B, T*N, d]`.
            2. Project to Q, K, V of shape `[B, heads, T*N, d_head]`.
            3. Apply RoPE to Q and K if enabled.
            4. Compute attention scores `Q·K^T / attention_denom`.
            5. Apply mask and spherical harmonics bias if enabled.
            6. Compute attention weights and multiply by V.
            7. Reshape output to `[B, T, N, d]`.
        """
        B, T, N, d = tensor.shape
        tensor_flat = tensor.view(B, T * N, d)

        key_mask_for_scores: torch.Tensor | None = None
        rope_mask_for_rope: torch.Tensor | None = None
        if mask is not None:
            assert mask.shape == (B, T, N, 1), f"Expected mask shape (B,T,N,1) but got {mask.shape}"
            reshaped_mask = mask.reshape(B, T * N)
            key_mask_for_scores = reshaped_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T*N] for attention scores
            rope_mask_for_rope = reshaped_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, T*N, 1] for RoPE

        if self.positional_encoding_type == PositionalEncodingType.SINUSOIDAL and self.positional_encoding is not None:
            tensor_flat = self.positional_encoding(tensor_flat)

        q_proj: torch.Tensor = self.query(tensor_flat).view(B, T * N, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        if self.positional_encoding_type == PositionalEncodingType.TROPE and self.positional_encoding is not None:
            q_proj = self.positional_encoding(q_proj, rope_mask_for_rope, time_increments)
        elif self.positional_encoding_type == PositionalEncodingType.ROPE and self.positional_encoding is not None:
            q_proj = self.positional_encoding(q_proj, rope_mask_for_rope)

        kv: torch.Tensor = self.kv_projs(tensor_flat)
        k_proj, v_proj = torch.chunk(kv, 2, dim=-1)
        k_proj = k_proj.view(B, N * T, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        v_proj = v_proj.view(B, N * T, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        if self.positional_encoding_type == PositionalEncodingType.TROPE and self.positional_encoding is not None:
            k_proj = self.positional_encoding(k_proj, rope_mask_for_rope, time_increments)
        elif self.positional_encoding_type == PositionalEncodingType.ROPE and self.positional_encoding is not None:
            k_proj = self.positional_encoding(k_proj, rope_mask_for_rope)

        scores: torch.Tensor = (q_proj @ k_proj.transpose(-2, -1)) / self.sqrt_dhead
        if key_mask_for_scores is not None:
            scores = scores.masked_fill(key_mask_for_scores == 0, float("-inf"))

        attn_weights: torch.Tensor = self.attention_dropout(F.softmax(scores, dim=-1))
        processed_out = attn_weights @ v_proj

        permuted_processed_out = processed_out.permute(0, 2, 1, 3).reshape(B, T * N, self.lifting_dim)
        final_out_projection: torch.Tensor = self.out_proj(permuted_processed_out).view(B, T, N, self.lifting_dim)
        return final_out_projection
