import flax.linen as nn
import jax.numpy as jnp
from layers.attention import AttentionBlock
from layers.feed_forward import FeedForward
from layers.positional_encoding import PositionalEncoding, NaivePositionalEncoding
from pretrain.mlp import PretrainedPositionalEncoding
from typing import Any

class CrystalFourierTransformer(nn.Module):
    """Attention model mapping a padded set of atoms in a crystal to one output.

    Pipeline: atomic-number embedding + positional encoding -> BatchNorm ->
    stack of ``AttentionBlock`` -> masked mean-pool over atoms -> ``FeedForward``.

    When ``config['fourier']`` is truthy, two ``PretrainedPositionalEncoding``
    modules are built. One is used for samples whose space-group number is in
    [143, 194], called "hexagonal" here; in International Tables numbering
    that range covers the trigonal and hexagonal crystal systems. The other,
    called "cubic" here, is used for every other space group. Both are
    evaluated for the whole batch and the result is selected per sample.
    Otherwise a single ``NaivePositionalEncoding`` is used.

    Attributes:
        config: Hyper-parameters. Keys read directly here: ``'embedding_dim'``,
            ``'fourier'``, ``'num_attn_blocks'`` and the optional
            ``'num_atom_types'`` (default 101, the number of rows in the
            atom embedding table). The dict is also forwarded to sub-layers.
        cubic_abc_combinations / hexagonal_abc_combinations: Passed as
            ``abc_combinations`` to the respective pretrained encoding.
        cubic_adj_matrices / hexagonal_adj_matrices: Passed as
            ``adjacency_matrices`` to the respective pretrained encoding.
        cubic_pretrained_state / hexagonal_pretrained_state: Passed as
            ``pretrained_state`` to the respective pretrained encoding.
        cubic_encoding_config / hexagonal_encoding_config: Passed as
            ``config`` to the respective pretrained encoding.
    """

    config: dict
    cubic_abc_combinations: jnp.ndarray  # For non-hexagonal space groups
    hexagonal_abc_combinations: jnp.ndarray  # For space groups 143-194
    cubic_adj_matrices: jnp.ndarray  # For non-hexagonal space groups
    hexagonal_adj_matrices: jnp.ndarray  # For space groups 143-194
    cubic_pretrained_state: Any = None  # For non-hexagonal space groups
    hexagonal_pretrained_state: Any = None  # For space groups 143-194
    cubic_encoding_config: dict = None  # For non-hexagonal space groups
    hexagonal_encoding_config: dict = None  # For space groups 143-194

    def setup(self):
        """Create the embedding, positional encoding(s), norm, blocks and head."""
        # Rows of the atom-type table; 101 preserves the original default.
        self.atom_embedding = nn.Embed(
            num_embeddings=self.config.get('num_atom_types', 101),
            features=self.config['embedding_dim'],
        )

        if self.config['fourier']:
            # Build both encodings; __call__ picks one per sample by space group.
            self.cubic_positional_encoding = PretrainedPositionalEncoding(
                config=self.cubic_encoding_config,
                abc_combinations=self.cubic_abc_combinations,
                adjacency_matrices=self.cubic_adj_matrices,
                pretrained_state=self.cubic_pretrained_state,
            )
            self.hexagonal_positional_encoding = PretrainedPositionalEncoding(
                config=self.hexagonal_encoding_config,
                abc_combinations=self.hexagonal_abc_combinations,
                adjacency_matrices=self.hexagonal_adj_matrices,
                pretrained_state=self.hexagonal_pretrained_state,
            )
        else:
            self.positional_encoding = NaivePositionalEncoding(config=self.config)

        self.input_norm = nn.BatchNorm()
        self.attention_blocks = [
            AttentionBlock(self.config)
            for _ in range(self.config['num_attn_blocks'])
        ]
        self.final_ff = FeedForward(self.config)

    def __call__(self, atom_numbers, atom_positions, lattice_matrices, space_groups, masks, training=True, rngs=None):
        """Run the model on a padded batch of crystals.

        Args:
            atom_numbers: Integer atom-type indices into the embedding table;
                presumably (batch, n_atoms) - TODO confirm against caller.
            atom_positions: Forwarded to the positional encoding(s).
            lattice_matrices: Forwarded to the positional encoding(s).
            space_groups: Per-sample space-group numbers, rank 1 (batch,)
                since it is indexed as ``[:, None, None]`` below.
            masks: (batch, n_atoms) validity mask; non-zero entries mark real
                atoms. It is passed to the attention blocks and used for pooling.
            training: If True, BatchNorm uses batch statistics and the
                attention blocks get ``deterministic=False``.
            rngs: Forwarded unchanged to every attention block.

        Returns:
            Output of the final feed-forward head applied to the masked mean of
            the per-atom features, one row per sample.
        """
        atom_embeddings = self.atom_embedding(atom_numbers)

        if self.config['fourier']:
            # Space groups 143-194 (trigonal + hexagonal) use the hexagonal
            # encoding; all others use the "cubic" one.
            is_hexagonal = (space_groups >= 143) & (space_groups <= 194)
            cubic_pos_encodings = self.cubic_positional_encoding(
                atom_positions, lattice_matrices, space_groups
            )
            hexagonal_pos_encodings = self.hexagonal_positional_encoding(
                atom_positions, lattice_matrices, space_groups
            )
            # Per-sample select; both branches are computed because a traced
            # per-sample Python branch is not possible under jit.
            pos_encodings = jnp.where(
                is_hexagonal[:, None, None],
                hexagonal_pos_encodings,
                cubic_pos_encodings,
            )
        else:
            pos_encodings = self.positional_encoding(atom_positions, lattice_matrices, space_groups)

        x = atom_embeddings + pos_encodings
        x = self.input_norm(x, use_running_average=not training)

        for block in self.attention_blocks:
            x = block(x, mask=masks, deterministic=not training, rngs=rngs)

        # Masked mean over the atom axis. Samples with no valid atoms would
        # divide by zero and emit NaN; give them a divisor of 1 so they pool
        # to zeros instead. Non-empty samples are unaffected.
        x = x * masks[:, :, None]
        atom_counts = jnp.sum(masks, axis=1, keepdims=True)
        atom_counts = jnp.where(atom_counts > 0, atom_counts, 1)
        x = jnp.sum(x, axis=1) / atom_counts
        return self.final_ff(x)