import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch import Tensor
from typing import Optional
import math

from torch_scatter import scatter_softmax,scatter_sum

try:
    from torch_cluster import knn_graph  
except ImportError:
    knn_graph = None

# Assumes these modules are in your project structure
from .utils import scatter_add
from .rope import PlatonicRoPE
from .groups import PLATONIC_GROUPS
from .addrope import AddRoPE
from .conv import PlatonicConv
from .linear_fourier import TetraFourierLinear, TetraFourierLinearQuarterBatch
from .utils_fourier import ToTetraFourier, FromTetraFourier, ToTetraFourierQuarterBatch, FromTetraFourierQuarterBatch


class TetraFourierConv(nn.Module):
    """
    Computes a group-equivariant dynamic convolution supporting both graph and dense modes.

    This layer uses Rotary Positional Embeddings (RoPE) to compute a dynamic
    convolution kernel. It supports two modes for dense data:
    1.  attention=False (Default): A highly efficient linear convolutio type "attention" mechanism.
    2.  attention=True: Standard scaled dot-product attention with softmax.
    
    Graph-structured data only uses the linear attention mechanism.
    The layer is equivariant to the symmetries of a specified Platonic solid.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        embed_dim: int,
        num_heads: int,
        spatial_dims: int = 3,
        freq_sigma: float = 1.0,
        learned_freqs: bool = True,
        bias: bool = True,
        mean_aggregation: bool = False,
        attention: bool = False,
        attention_type: str = 'equivariant',
        fourier_type: str = 'quarter_batch',
    ):
        super().__init__()

        # --- Group Setup ---
        solid_name = "tetrahedron"
        self.group = PLATONIC_GROUPS[solid_name]
        self.num_G = self.group.G
        
        # --- Dimension Validation and Setup ---
        if in_channels % self.num_G != 0:
            raise ValueError(f"in_channels ({in_channels}) must be divisible by group size ({self.num_G}).")
        self.in_channels_g = in_channels // self.num_G

        if out_channels % self.num_G != 0:
            raise ValueError(f"out_channels ({out_channels}) must be divisible by group size ({self.num_G}).")
        self.out_channels_g = out_channels // self.num_G

        if num_heads % self.num_G != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by group size ({self.num_G}).")

        if embed_dim % (num_heads // self.num_G) != 0:
             raise ValueError(f"embed_dim ({embed_dim}) must be divisible by (num_heads // num_group) = "
                             f"{self.num_G // num_heads}.")
        if fourier_type == 'standard':
            LinearLayer = TetraFourierLinear
        elif fourier_type == 'quarter_batch':
            LinearLayer = TetraFourierLinearQuarterBatch
        else:
            raise ValueError()
        self.fourier_type = fourier_type

        self.embed_dim = embed_dim
        self.embed_dim_g = embed_dim // self.num_G
       
        self.out_channels = out_channels
        self.effective_num_heads = num_heads//self.num_G
        self.head_dim = self.embed_dim_g // self.effective_num_heads

        self.mean_aggregation = mean_aggregation
        self.attention = attention
        self.attention_type = attention_type

        # --- Sub-modules ---
        self.q_proj = LinearLayer(in_channels, embed_dim, bias=bias, transform_to_fourier=False, transform_back_from_fourier=True)
        self.v_proj = LinearLayer(in_channels, embed_dim, bias=bias, transform_to_fourier=False, transform_back_from_fourier=True)
        if freq_sigma is None:
            self.k_proj = LinearLayer(in_channels, embed_dim, bias=bias, transform_to_fourier=False, transform_back_from_fourier=True)
        else:
            self.register_buffer('k_proj', None)

        # Group-equivariant RoPE for positional information
        if freq_sigma is not None:
            self.rope_emb = PlatonicRoPE(
                embed_dim=embed_dim,
                num_heads=self.effective_num_heads,
                head_dim=self.head_dim,
                solid_name=solid_name,
                spatial_dims=spatial_dims,
                freq_sigma=freq_sigma,
                learned_freqs=learned_freqs,
            )
        else:
            self.register_buffer('rope_emb', None)

        # Final equivariant linear layer
        self.out_proj = LinearLayer(embed_dim, out_channels, bias=bias, transform_to_fourier=True, transform_back_from_fourier=False)

    def init_from_non_fourier(self, conv: PlatonicConv):
        # TODO: Does not fail gracefully if the provided conv is not compatible...
        if conv.group.solid_name != self.group.solid_name:
            raise ValueError()
        self.q_proj.reset_parameters(conv.q_proj)
        self.v_proj.reset_parameters(conv.v_proj)
        if self.k_proj is not None:
            self.k_proj.reset_parameters(conv.k_proj)
        self.out_proj.reset_parameters(conv.out_proj)
        if self.rope_emb is not None:
            with torch.no_grad():
                self.rope_emb.freqs.copy_(conv.rope_emb.freqs.clone().detach())

    
    def _forward_shared(self, x: Tensor, pos: Tensor):
        """Shared logic for projections and RoPE application."""
        q_raw = self.q_proj(x)
        v_raw = self.v_proj(x)
        # If not using RoPE, then project, but if using RoPE then use ones
        k_raw = self.k_proj(x) if self.rope_emb is None else torch.ones_like(q_raw)

        leading_dims = q_raw.shape[:-1]

        # Reshape for multi-head processing: [..., G * H * D_h] -> [..., G, H, D_h]
        q = q_raw.view(*leading_dims, self.num_G, self.effective_num_heads, self.head_dim)
        v = v_raw.view(*leading_dims, self.num_G, self.effective_num_heads, self.head_dim)
        k = k_raw.view(*leading_dims, self.num_G, self.effective_num_heads, self.head_dim)

        # Apply RoPE to query and key
        if self.rope_emb is not None:
            q = self.rope_emb(q, pos)
            k = self.rope_emb(k, pos)

        return q, k, v

    # scatter_softmax kernels don't play well with torch.compile
    @torch.compiler.disable()
    def _scatter_softmax(self, scores, group_ids, dim, dim_size):
        return scatter_softmax(scores, group_ids, dim=dim, dim_size=dim_size)

    def graph_scattered_attention(self,
        q: torch.Tensor,      # [N, G, H, D]
        k: torch.Tensor,      # [N, G, H, D]
        v: torch.Tensor,      # [N, G, H, D]
        batch: torch.Tensor,  # [N]
        pos: torch.Tensor | None = None,     
        edge_index: torch.Tensor | None = None,
        k_knn: int | None = None
    ) -> torch.Tensor:
        """
        Compute full connected edge if edge_index is None, or kNN edges if k_knn is given.
        Supports both equivariant and invariant attention modes.
        
        Returns
        -------
        out : Tensor, shape [N, G*H*D]
        """
        N, G, H, D = q.shape
        device = q.device

        if edge_index is not None:
            src, dst = edge_index.to(device)
        elif k_knn is not None:
            if pos is None:
                raise ValueError("k_knn was given but 'pos' is None.")
            if knn_graph is None:
                raise ImportError("torch_cluster.knn_graph is required for kNN mode.")
            edge_index = knn_graph(
                x=pos.to(device), k=k_knn, batch=batch, loop=True
            )                                        # [2, |E_knn|]
            src, dst = edge_index
        else:
            N = batch.shape[0]
            # Create a dense N x N grid of all possible edges
            node_idx = torch.arange(N, device=device)
            src, dst = torch.meshgrid(node_idx, node_idx, indexing='ij')

            # Keep only the edges where the source and destination nodes
            # belong to the same graph in the batch.
            mask = batch[src] == batch[dst]
            
            # Apply the mask to get the final edge index
            src = src[mask]
            dst = dst[mask]

        E = src.numel()
        
        if self.attention_type == 'invariant':
            # Permute to [N, H, G, D] and reshape to [N, H, G*D]
            q_inv = q.permute(0, 2, 1, 3).reshape(N, H, G * D)
            k_inv = k.permute(0, 2, 1, 3).reshape(N, H, G * D)
            v_inv = v.permute(0, 2, 1, 3).reshape(N, H, G * D)
            
            q_src = q_inv[src]  # [E, H, G*D]
            k_dst = k_inv[dst]  # [E, H, G*D]
            v_dst = v_inv[dst]  # [E, H, G*D]
            
            # Integration over h_dim and also G (so invariant)
            scores = (q_src * k_dst).sum(-1) * ((G * D) ** -0.5)  # [E, H]
            
            head_ids = torch.arange(H, device=device).repeat(E, 1)       # [E, H]
            node_head_ids = src.unsqueeze(1) * H + head_ids              # [E, H]
            
            # Apply softmax normalization
            a = self._scatter_softmax(
                scores.flatten(),
                node_head_ids.flatten(),
                dim=0,
                dim_size=N * H
            ).view(E, H)                                             
            
            weighted = (a.unsqueeze(-1) * v_dst)                        # [E, H, G*D]
            weighted = weighted.reshape(-1, G*D)                        # [E*H, G*D]
            
            out = scatter_sum(
                weighted,
                node_head_ids.flatten(),
                dim=0,
                dim_size=N * H
            ).view(N, H, G*D)                                          

            # Reshape (N, H, G*D) -> (N, G, H, D) -> (N, G*H*D)
            out = out.reshape(N, H, G, D).permute(0, 2, 1, 3).reshape(N, G*H*D)

        elif self.attention_type == 'equivariant':
            
            GH = G * H  
            q_src = q.reshape(N, GH, D)[src]  # [E, GH, D]
            k_dst = k.reshape(N, GH, D)[dst]  # [E, GH, D]
            v_dst = v.reshape(N, GH, D)[dst]  # [E, GH, D]

            scores = (q_src * k_dst).sum(-1) * D ** -0.5  # [E, GH]

            # reindex ids for heads
            head_ids = torch.arange(GH, device=device).repeat(E, 1)     # [E, GH]
            group_ids = src.unsqueeze(1) * GH + head_ids                # [E, GH]

            a = self._scatter_softmax(
                scores.flatten(),
                group_ids.flatten(),
                dim=0,
                dim_size=N * GH
            ).view(E, GH)                                              

            weighted = (a.unsqueeze(-1) * v_dst).reshape(-1, D)         # [E*GH, D]

            out = scatter_sum(
                weighted,
                group_ids.flatten(),
                dim=0,
                dim_size=N * GH
            ).view(N, GH, D)                                            

            # Reshape (N, GH, D) -> (N, G*H*D)
            out = out.reshape(N, G*H*D)                                
    
        return out
        


    def _forward_graph(self, x: Tensor, pos: Tensor, batch: Tensor, avg_num_nodes=1.0):
        """
        Implementation for graph-structured data.
        Supports both kernelized linear attention and standard softmax attention.
        """
        q_rope, k_rope, v = self._forward_shared(x, pos) # [N, G, H, D_h]
  
        if self.attention:
            output = self.graph_scattered_attention(q_rope, k_rope, v, batch, pos)
        else:
            kv_outer_product = torch.einsum('nghd,nghe->nghde', k_rope, v)
            num_graphs = batch.max() + 1
            kv_kernel = scatter_add(kv_outer_product, batch, dim_size=num_graphs)

            if self.mean_aggregation:
                num_nodes = scatter_add(torch.ones_like(batch, dtype=torch.float), batch, dim_size=num_graphs)[..., None, None, None, None]
            else:
                num_nodes = avg_num_nodes
            kv_kernel = kv_kernel / num_nodes
            
            output = torch.einsum('nghd,nghde->nghe', q_rope, kv_kernel[batch])
            output = output.flatten(-3, -1) # -> (..., G, H, H_dim) -> (..., G*H*H_dim)

        return self.out_proj(output)

    # TODO: Below does not work well with torch.compile for unknown reason
    @torch.compiler.disable()
    def _conv_dense(self, q_rope, k_rope, v, mask, batch_size, sequence_length, avg_num_nodes):
        B, S = batch_size, sequence_length
        if self.attention:
            if self.attention_type == 'invariant':
                # Invariant: Permute to (B, S, H, G, Dh) and merge G and Dh
                q_perm = q_rope.permute(0, 1, 3, 2, 4).reshape(B, S, self.effective_num_heads, self.num_G * self.head_dim)
                k_perm = k_rope.permute(0, 1, 3, 2, 4).reshape(B, S, self.effective_num_heads, self.num_G * self.head_dim)
                v_perm = v.permute(0, 1, 3, 2, 4).reshape(B, S, self.effective_num_heads, self.num_G * self.head_dim)
                # Reshape for SDPA: (B, S, H, G*Dh) -> (B, H, S, G*Dh)
                q_sdpa = q_perm.transpose(1, 2)
                k_sdpa = k_perm.transpose(1, 2)
                v_sdpa = v_perm.transpose(1, 2)
            elif self.attention_type == 'equivariant':
                # Reshape for scaled_dot_product_attention: (B, S, G, H, Dh) -> (B, G*H, S, Dh)
                q_sdpa = q_rope.view(B, S, self.num_G * self.effective_num_heads, self.head_dim).transpose(1, 2)
                k_sdpa = k_rope.view(B, S, self.num_G * self.effective_num_heads, self.head_dim).transpose(1, 2)
                v_sdpa = v.view(B, S, self.num_G * self.effective_num_heads, self.head_dim).transpose(1, 2)
            else:
                raise ValueError(f"Unknown attention head type: {self.attention_type}. Should be 'invariant' or 'equivariant'.")

            attn_mask = mask[:, None, None, :] if mask is not None else None
            with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):      
               attn_output = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=attn_mask)

            # Reshape back for output projection: (B, G*H, S, Dh)/(B, H, S, G*Dh) -> (B, S, G*H*Dh)
            return attn_output.transpose(1, 2).reshape(B, S, self.embed_dim)
        else:
            if mask is not None:
                # Apply mask before aggregation
                v = v * mask[..., None, None, None]
                k_rope = k_rope * mask[..., None, None, None]

            kv_kernel = torch.einsum('bsghd,bsghe->bghde', k_rope, v)

            if self.mean_aggregation and mask is not None:
                num_nodes = mask.sum(dim=-1).float().view(B, 1, 1, 1, 1)
            else:
                num_nodes = avg_num_nodes
            kv_kernel = kv_kernel / num_nodes

            output = torch.einsum('bsghd,bghde->bsghe', q_rope, kv_kernel)
            return output.flatten(-3, -1)

    def _forward_dense(self, x: Tensor, pos: Tensor, mask: Tensor, avg_num_nodes=1.0):
        """
        Implementation for dense, padded data.
        Supports both linear and standard softmax attention.
        """
        q_rope, k_rope, v = self._forward_shared(x, pos)
        B, S = x.shape[:2] # B: batch size, S: sequence length

        # if self.attention:
        #     if self.attention_type == 'invariant':
        #         # Invariant: Permute to (B, S, H, G, Dh) and merge G and Dh
        #         q_perm = q_rope.permute(0, 1, 3, 2, 4).reshape(B, S, self.effective_num_heads, self.num_G * self.head_dim)
        #         k_perm = k_rope.permute(0, 1, 3, 2, 4).reshape(B, S, self.effective_num_heads, self.num_G * self.head_dim)
        #         v_perm = v.permute(0, 1, 3, 2, 4).reshape(B, S, self.effective_num_heads, self.num_G * self.head_dim)
        #         # Reshape for SDPA: (B, S, H, G*Dh) -> (B, H, S, G*Dh)
        #         q_sdpa = q_perm.transpose(1, 2)
        #         k_sdpa = k_perm.transpose(1, 2)
        #         v_sdpa = v_perm.transpose(1, 2)
        #     elif self.attention_type == 'equivariant':
        #         # Reshape for scaled_dot_product_attention: (B, S, G, H, Dh) -> (B, G*H, S, Dh)
        #         q_sdpa = q_rope.view(B, S, self.num_G * self.effective_num_heads, self.head_dim).transpose(1, 2)
        #         k_sdpa = k_rope.view(B, S, self.num_G * self.effective_num_heads, self.head_dim).transpose(1, 2)
        #         v_sdpa = v.view(B, S, self.num_G * self.effective_num_heads, self.head_dim).transpose(1, 2)
        #     else:
        #         raise ValueError(f"Unknown attention head type: {self.attention_type}. Should be 'invariant' or 'equivariant'.")

        #     attn_mask = mask[:, None, None, :] if mask is not None else None
        #     with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):      
        #        attn_output = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=attn_mask)

        #     # Reshape back for output projection: (B, G*H, S, Dh)/(B, H, S, G*Dh) -> (B, S, G*H*Dh)
        #     output = attn_output.transpose(1, 2).reshape(B, S, self.embed_dim)
        # else:
        #     if mask is not None:
        #         # Apply mask before aggregation
        #         v = v * mask[..., None, None, None]
        #         k_rope = k_rope * mask[..., None, None, None]

        #     kv_kernel = torch.einsum('bsghd,bsghe->bghde', k_rope, v)

        #     if self.mean_aggregation and mask is not None:
        #         num_nodes = mask.sum(dim=-1).float().view(B, 1, 1, 1, 1)
        #     else:
        #         num_nodes = avg_num_nodes
        #     kv_kernel = kv_kernel / num_nodes

        #     output = torch.einsum('bsghd,bghde->bsghe', q_rope, kv_kernel)
        #     output = output.flatten(-3, -1)
        output = self._conv_dense(q_rope, k_rope, v, mask, B, S, avg_num_nodes=avg_num_nodes)
        
        return self.out_proj(output)
    
    def forward(
        self,
        x: Tensor,
        pos: Tensor,
        batch: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        avg_num_nodes: Optional[float] = 1.0
    ) -> Tensor:
        is_graph_mode = batch is not None
        avg_num_nodes = avg_num_nodes if avg_num_nodes is not None else 1.0
        
        if is_graph_mode:
            if mask is not None:
                raise ValueError("Only one of 'batch' or 'mask' can be provided.")
            return self._forward_graph(x, pos, batch, avg_num_nodes=avg_num_nodes)
        else:
            return self._forward_dense(x, pos, mask, avg_num_nodes=avg_num_nodes)

def test_equivalence():
    """
    Smoke test: TetraFourierConv must reproduce PlatonicConv on the tetrahedron.

    A randomly initialised ``PlatonicConv`` is copied into both Fourier variants
    (``'standard'`` and ``'quarter_batch'``) via ``init_from_non_fourier``. All
    three are run on the same random dense batch (no mask) and must agree to
    ``atol=1e-4``.
    """
    # Fixed seed so the random weights/inputs, and hence the tolerance check,
    # are reproducible from run to run.
    torch.manual_seed(0)

    B = 16
    N = 32
    in_c = 768
    out_c = 384
    embed_dim = 288
    num_heads = 4 * 12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    x = torch.randn([B, N, in_c], device=device)
    pos = torch.randn([B, N, 3], device=device)

    conv = PlatonicConv(
        in_channels=in_c,
        out_channels=out_c,
        embed_dim=embed_dim,
        num_heads=num_heads,
        solid_name="tetrahedron",
    ).to(device)

    conv_f1 = TetraFourierConv(
        in_channels=in_c,
        out_channels=out_c,
        embed_dim=embed_dim,
        num_heads=num_heads,
        fourier_type="standard",
    ).to(device)
    conv_f1.init_from_non_fourier(conv)

    conv_f2 = TetraFourierConv(
        in_channels=in_c,
        out_channels=out_c,
        embed_dim=embed_dim,
        num_heads=num_heads,
        fourier_type="quarter_batch",
    ).to(device)
    conv_f2.init_from_non_fourier(conv)

    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        y = conv(x, pos)
        print(y[5, 2, :32])

        # NOTE(review): the 'standard' forward output is star-unpacked into
        # from_fourier while 'quarter_batch' is not -- presumably the standard
        # layers return a tuple of Fourier components; confirm in linear_fourier.
        y_f1 = conv_f1.q_proj.from_fourier(*conv_f1(conv_f1.out_proj.to_fourier(x), pos))
        print(y_f1[5, 2, :32])
        assert torch.allclose(y, y_f1, atol=1e-4), "Outputs should be equal"

        y_f2 = conv_f2.q_proj.from_fourier(conv_f2(conv_f2.out_proj.to_fourier(x), pos))
        print(y_f2[5, 2, :32])
        assert torch.allclose(y, y_f2, atol=1e-4), "Outputs should be equal"

    print("Tests passed!")


if __name__ == "__main__":
    test_equivalence()
