import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Union, Callable
import torch.nn.functional as F

# Assumes these modules are in your project structure
from .conv_fourier import TetraFourierConv
from .linear_fourier import TetraFourierLinear, TetraFourierLinearQuarterBatch
from .groups import PLATONIC_GROUPS
from .block import PlatonicBlock


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Apply stochastic depth: zero out whole samples of ``x`` at random.

    Args:
        x: Input tensor; dimension 0 is treated as the sample dimension.
        drop_prob: Probability of dropping each sample.
        training: The input is returned untouched unless this is True.
        scale_by_keep: If True (and the keep probability is positive), the
            surviving samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise ``x`` multiplied
        by a per-sample mask.
    """
    # Nothing to do in eval mode or when dropping is disabled.
    if not training or drop_prob == 0.:
        return x

    survival_rate = 1 - drop_prob
    # One Bernoulli draw per sample, broadcastable over every remaining dim
    # (so this works for tensors of any rank, not just 2D ConvNet features).
    mask_shape = (x.shape[0], *((1,) * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(survival_rate)
    if scale_by_keep and survival_rate > 0.0:
        mask.div_(survival_rate)
    return x * mask

class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (per-sample stochastic depth).

    Intended for the main path of residual blocks. Dropping is only active
    while the module is in training mode.

    Args:
        drop_prob: Probability of dropping each sample.
        scale_by_keep: Whether surviving samples are rescaled by the keep
            probability.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # ``self.training`` is toggled by ``train()`` / ``eval()``.
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

class TetraFourierRMSNormQuarterBatch(nn.Module):
    def __init__(self, channels, eps=None, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight1 = nn.Parameter(torch.empty(channels))
            self.weight2 = nn.Parameter(torch.empty(channels))
            self.weight3 = nn.Parameter(torch.empty(3*channels))
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            nn.init.ones_(self.weight1)
            nn.init.ones_(self.weight2)
            nn.init.ones_(self.weight3)

    def get_weight(self):
        if not self.elementwise_affine:
            return None
        return torch.vstack((
            torch.cat((self.weight1, self.weight2, self.weight2))[None],
            self.weight3[None].expand(3, -1),
        ))

    def forward(self, x):
        return F.rms_norm(x, (4, 3*self.channels), weight=self.get_weight(), eps=self.eps)

class TetraFourierBlock(nn.Module):
    """Residual block: an interaction sub-layer followed by a feed-forward sub-layer.

    The interaction sub-layer is a ``TetraFourierConv`` and the feed-forward
    sub-layer is two Fourier linear layers with an activation in between.
    Each sub-layer is wrapped in RMS normalization (pre- or post-norm),
    dropout, optional stochastic depth and a residual connection. The
    symmetry group is fixed to ``PLATONIC_GROUPS["tetrahedron"]``.

    Args:
        d_model (int): The total model dimension. Must be divisible by the
                       group size and by ``nhead``.
        nhead (int): The number of attention heads, passed to the interaction
                     layer as ``num_heads``.
        dim_feedforward (int): The total dimension of the feed-forward network's
                               hidden layer. Must be divisible by the group size.
        dropout (float): Dropout rate used after the interaction layer, inside
                         the FFN and after the FFN.
        activation (Callable): The activation function for the FFN.
        rms_norm_eps (float): Epsilon for RMSNorm.
        norm_first (bool): If True, applies pre-normalization; otherwise, post-normalization.
        spatial_dims (int): The number of spatial dimensions for positions.
        drop_path (float): Stochastic depth rate. Default: 0.0.
        layer_scale_init_value (Optional[float]): LayerScale is not implemented
                                                  yet; must be None. Default: None.
        freq_sigma (float): Forwarded unchanged to ``TetraFourierConv``.
        learned_freqs (bool): Forwarded unchanged to ``TetraFourierConv``.
        mean_aggregation (bool): Forwarded unchanged to ``TetraFourierConv``.
        attention (bool): Forwarded unchanged to ``TetraFourierConv``.
        attention_type (str): Forwarded unchanged to ``TetraFourierConv``.
        fourier_type (str): Which Fourier linear implementation to use. Only
                            ``'quarter_batch'`` is currently supported.
        **kwargs: Additional keyword arguments for the ``TetraFourierConv`` layer.

    Raises:
        ValueError: If a dimension is not divisible as required above, or if
                    ``fourier_type`` is not a known value.
        NotImplementedError: If ``fourier_type == 'standard'`` or
                             ``layer_scale_init_value`` is not None.
    """
    def __init__(self,
                 d_model: int,
                 nhead: int,
                 dim_feedforward: int,
                 dropout: float = 0.1,
                 activation: Callable[[Tensor], Tensor] = F.gelu,
                 rms_norm_eps: float = 1e-5,
                 norm_first: bool = True,
                 spatial_dims: int = 3,
                 drop_path: float = 0.0,
                 layer_scale_init_value: Optional[float] = None,
                 freq_sigma: float = 1.0,
                 learned_freqs: bool = True,
                 mean_aggregation: bool = False,
                 attention: bool = False,
                 attention_type: str = 'equivariant',
                 fourier_type: str = "quarter_batch",
                 **kwargs) -> None:
        super().__init__()

        # --- Group and Dimension Setup ---
        # The group is not configurable: this block is tetrahedron-only.
        solid_name = "tetrahedron"
        self.group = PLATONIC_GROUPS[solid_name]
        self.num_G = self.group.G
        self.norm_first = norm_first

        # Validate total dimensions against group size and heads
        if d_model % self.num_G != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by group size ({self.num_G}).")
        if dim_feedforward % self.num_G != 0:
            raise ValueError(f"dim_feedforward ({dim_feedforward}) must be divisible by group size ({self.num_G}).")
        if d_model % (nhead) != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_head = {nhead}.")

        # Select the Fourier linear implementation.
        if fourier_type == 'standard':
            # TODO: use TetraFourierLinear here once the standard variant is supported.
            raise NotImplementedError("fourier_type='standard' is not implemented yet.")
        if fourier_type != 'quarter_batch':
            raise ValueError(
                f"Unknown fourier_type {fourier_type!r}; expected 'standard' or 'quarter_batch'."
            )
        LinearLayer = TetraFourierLinearQuarterBatch
        self.fourier_type = fourier_type

        # Fail fast, before any sub-module is built.
        # TODO: make layerscale compatible with irrep approach (a plain per-channel
        # scale of size d_model may not be equivariant).
        if layer_scale_init_value is not None:
            raise NotImplementedError("LayerScale (layer_scale_init_value) is not implemented yet.")

        # Calculate per-group-element dimensions
        self.dim_per_g = d_model // self.num_G

        # --- Equivariant Sub-Modules ---
        self.interaction = TetraFourierConv(
            in_channels=d_model,
            out_channels=d_model,
            embed_dim=d_model,
            num_heads=nhead,
            spatial_dims=spatial_dims,
            freq_sigma=freq_sigma,
            learned_freqs=learned_freqs,
            mean_aggregation=mean_aggregation,
            attention=attention,
            attention_type=attention_type,
            fourier_type=fourier_type,
            **kwargs
        )

        # Equivariant Feed-Forward Network. linear1 leaves the Fourier
        # representation on output and linear2 re-enters it on input, so the
        # activation in between is applied outside the Fourier representation.
        self.linear1 = LinearLayer(d_model, dim_feedforward, transform_to_fourier=False, transform_back_from_fourier=True)
        self.linear2 = LinearLayer(dim_feedforward, d_model, transform_to_fourier=True, transform_back_from_fourier=False)

        # Normalization over the trailing (4, 3 * dim_per_g) dimensions.
        # TODO: potentially implement equivariant layer normalization in the irrep basis,
        # RMSNorm is simpler though.
        self.norm1 = TetraFourierRMSNormQuarterBatch(self.dim_per_g, eps=rms_norm_eps)
        self.norm2 = TetraFourierRMSNormQuarterBatch(self.dim_per_g, eps=rms_norm_eps)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.ffn_dropout = nn.Dropout(dropout)
        self.activation = activation

        # --- DropPath and LayerScale ---
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # LayerScale is disabled (see the check above); forward() skips the
        # scaling while these are None.
        self.gamma_1 = None
        self.gamma_2 = None

    def forward(
        self,
        x: Tensor,
        pos: Tensor,
        batch: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        avg_num_nodes = 1.0
    ) -> Tensor:
        """
        Args:
            x (Tensor): Input feature tensor. The norm layers require its last
                two dimensions to be ``(4, 3 * d_model // group_size)``.
            pos (Tensor): Position tensor of shape [..., D_spatial].
            batch (Optional[Tensor]): For graph mode. Batch index for each element.
            mask (Optional[Tensor]): For dense mode. Boolean mask.
            avg_num_nodes: Forwarded unchanged to the interaction layer.

        Returns:
            Tensor: Output feature tensor after both residual sub-layers.
                NOTE(review): expected to have the same shape as ``x`` --
                confirm against ``TetraFourierConv``.
        """
        if self.norm_first:
            # 1. Interaction Block (Pre-Norm)
            interaction_out = self._interaction_block(self.norm1(x), pos, batch, mask, avg_num_nodes)
            if self.gamma_1 is not None:
                interaction_out = self.gamma_1 * interaction_out
            x = x + self.drop_path1(interaction_out)

            # 2. Feed-Forward Block (Pre-Norm)
            ff_output = self._ff_block(self.norm2(x))
            if self.gamma_2 is not None:
                ff_output = self.gamma_2 * ff_output
            x = x + self.drop_path2(ff_output)
        else:
            # 1. Interaction Block (Post-Norm)
            interaction_out = self._interaction_block(x, pos, batch, mask, avg_num_nodes)
            if self.gamma_1 is not None:
                interaction_out = self.gamma_1 * interaction_out
            x = self.norm1(x + self.drop_path1(interaction_out))

            # 2. Feed-Forward Block (Post-Norm)
            ff_output = self._ff_block(x)
            if self.gamma_2 is not None:
                ff_output = self.gamma_2 * ff_output
            x = self.norm2(x + self.drop_path2(ff_output))

        return x

    def _interaction_block(
        self, x: Tensor, pos: Tensor, batch: Optional[Tensor], mask: Optional[Tensor], avg_num_nodes = 1.0
    ) -> Tensor:
        """Run the TetraFourierConv interaction layer, then dropout."""
        interaction_output = self.interaction(x, pos, batch=batch, mask=mask, avg_num_nodes=avg_num_nodes)
        return self.dropout1(interaction_output)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Equivariant Feed-Forward Network block: linear1 -> activation -> dropout -> linear2 -> dropout."""
        hidden = self.ffn_dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))

def test_equivalence():
    """Check that TetraFourierBlock matches PlatonicBlock on random input.

    A reference ``PlatonicBlock`` and a ``TetraFourierBlock`` initialized
    from it are run on the same random features and positions. The norm
    layers of both are replaced by identities first. The Fourier block is
    fed ``interaction.out_proj.to_fourier(x)`` and its output is mapped back
    with ``interaction.q_proj.from_fourier`` before comparing.

    Raises:
        AssertionError: If the two outputs differ by more than ``atol=1e-4``.
    """
    batch_size, num_points = 16, 32
    d_model = 768
    ffn_dim = 288
    heads = 4 * 12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _strip_norms_and_eval(module):
        # Norm layers are currently not equivalent between implementations
        module.norm1 = nn.Identity()
        module.norm2 = nn.Identity()
        module.eval()

    with torch.inference_mode():
        feats = torch.randn(batch_size, num_points, d_model, device=device)
        coords = torch.randn(batch_size, num_points, 3, device=device)

        # Reference (non-Fourier) implementation.
        reference = PlatonicBlock(
            d_model=d_model,
            nhead=heads,
            dim_feedforward=ffn_dim,
            solid_name="tetrahedron",
        ).to(device)
        _strip_norms_and_eval(reference)
        expected = reference(feats, coords)
        print(expected[5, 2, :32])

        # Fourier implementation, with weights taken from the reference.
        fourier = TetraFourierBlock(
            d_model=d_model,
            nhead=heads,
            dim_feedforward=ffn_dim,
            fourier_type="quarter_batch",
        ).to(device)
        fourier.interaction.init_from_non_fourier(reference.interaction)
        fourier.linear1.reset_parameters(reference.linear1)
        fourier.linear2.reset_parameters(reference.linear2)
        _strip_norms_and_eval(fourier)

        conv = fourier.interaction
        fourier_feats = conv.out_proj.to_fourier(feats)
        actual = conv.q_proj.from_fourier(fourier(fourier_feats, coords))
        print(actual[5, 2, :32])
        assert torch.allclose(expected, actual, atol=1e-4), "Outputs should be equal"

    print("Tests passed!")

# Script entry point: run the equivalence self-test.
# NOTE: this file uses relative imports, so it has to be executed as a module
# of its package (``python -m <package>.<module>``), not as a bare script path.
if __name__ == "__main__":
    test_equivalence()
