import torch
import torch.nn as nn
from functools import partial
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def pair(t):
    """Return ``t`` as-is when it is already a tuple, else the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t,) * 2

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """
    2D positional embedding using sine and cosine functions

    Args:
        h (int): Number of rows of the grid
        w (int): Number of columns of the grid
        dim (int): Embedding dimension (must be multiple of 4)
        temperature (int): Temperature for frequency scaling
        dtype (torch.dtype): Data type of the output

    Returns:
        torch.Tensor: Positional embeddings of shape (h*w, dim). Row ``i * w + j``
        encodes grid position (row i, column j) as the concatenation
        ``[sin(j*omega), cos(j*omega), sin(i*omega), cos(i*omega)]`` with
        ``dim // 4`` frequencies in ``omega``.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    num_freqs = dim // 4

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    # Frequencies are spaced geometrically from 1 down to 1/temperature. With a
    # single frequency (dim == 4) the old normalisation computed 0/0 = NaN;
    # clamping the denominator yields omega = [1] instead.
    omega = torch.arange(num_freqs) / max(num_freqs - 1, 1)
    omega = 1.0 / (temperature ** omega)

    # (h*w, 1) * (1, num_freqs) -> (h*w, num_freqs)
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

def triple(t):
    """Return ``t`` as-is when it is already a tuple, else the 3-tuple ``(t, t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t,) * 3

def posemb_sincos_3d(d, h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """
    3D positional embedding using sine and cosine functions

    Args:
        d (int): Depth dimension size
        h (int): Height dimension size
        w (int): Width dimension size
        dim (int): Embedding dimension (must be multiple of 6)
        temperature (int): Temperature for frequency scaling
        dtype (torch.dtype): Data type of the output

    Returns:
        torch.Tensor: Positional embeddings of shape (d*h*w, dim). Row
        ``k * h * w + i * w + j`` encodes position (depth k, row i, column j) as
        ``[sin(j*omega), cos(j*omega), sin(i*omega), cos(i*omega),
        sin(k*omega), cos(k*omega)]`` with ``dim // 6`` frequencies in ``omega``.
    """
    assert (dim % 6) == 0, "feature dimension must be multiple of 6 for 3D sincos embedding"
    num_freqs = dim // 6

    z, y, x = torch.meshgrid(torch.arange(d), torch.arange(h), torch.arange(w), indexing="ij")

    # Frequencies are spaced geometrically from 1 down to 1/temperature. With a
    # single frequency (dim == 6) the old normalisation computed 0/0 = NaN;
    # clamping the denominator yields omega = [1] instead.
    omega = torch.arange(num_freqs) / max(num_freqs - 1, 1)
    omega = 1.0 / (temperature ** omega)

    # (d*h*w, 1) * (1, num_freqs) -> (d*h*w, num_freqs)
    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim=1)
    return pe.type(dtype)


class TrainableFourierEmbedding(nn.Module):
    """Learnable lookup table whose rows are sin/cos-expanded and folded into a volume.

    Each selected row (length ``embedding_dim``) is expanded element-wise with
    ``fourier_dim // 6`` frequencies, giving ``embedding_dim * 2 * (fourier_dim // 6)``
    features per index. ``forward`` then reshapes those features to
    ``(b, channels, depth, height, width)``.
    """

    def __init__(self, num_embeddings=10, embedding_dim=512, fourier_dim=None, temperature=10000):
        """
        Initialize trainable embeddings with Fourier encoding for 3D data

        Args:
            num_embeddings (int): Number of trainable embeddings (default: 10)
            embedding_dim (int): Dimension of each embedding
            fourier_dim (int): Sets the number of encoding frequencies
                (``fourier_dim // 6``); must be a multiple of 6 (checked when
                encoding). If None, uses embedding_dim. NOTE(review): the
                default 512 is not a multiple of 6, so ``forward`` fails unless
                a suitable value is passed explicitly.
            temperature (int): Temperature for Fourier encoding
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.fourier_dim = fourier_dim if fourier_dim is not None else embedding_dim
        self.temperature = temperature

        # Trainable table drawn from a standard normal: [num_embeddings, embedding_dim]
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))

    def _apply_fourier_encoding(self, x):
        """Expand every element of ``x`` into sin/cos features.

        Args:
            x (torch.Tensor): Tensor whose last dimension has size E.

        Returns:
            torch.Tensor: Same leading shape, last dimension
            ``E * 2 * (fourier_dim // 6)``. For each input element, its
            ``fourier_dim // 6`` sines are followed by its ``fourier_dim // 6`` cosines.
        """
        assert (self.fourier_dim % 6) == 0, "fourier_dim must be multiple of 6 for 3D sincos encoding"
        num_freqs = self.fourier_dim // 6

        # Frequencies spaced geometrically from 1 down to 1/temperature. Clamp the
        # denominator so a single frequency (fourier_dim == 6) gives omega = [1]
        # instead of 0/0 = NaN.
        omega = torch.arange(num_freqs, device=x.device) / max(num_freqs - 1, 1)
        omega = 1.0 / (self.temperature ** omega)

        # (..., E, 1) * (num_freqs,) -> (..., E, num_freqs)
        x_expanded = x.unsqueeze(-1) * omega

        # (..., E, 2 * num_freqs) -> (..., E * 2 * num_freqs)
        encoded = torch.cat([torch.sin(x_expanded), torch.cos(x_expanded)], dim=-1)
        return encoded.flatten(-2)

    def forward(self, idx, depth, height, width):
        """
        Forward pass for 3D data

        Args:
            idx (torch.Tensor): Indices of embeddings to use, shape (b,)
            depth (int): Depth of output feature volume
            height (int): Height of output feature volume
            width (int): Width of output feature volume

        Returns:
            torch.Tensor: Features of shape (b, c, d, h, w), where
            c = embedding_dim * 2 * (fourier_dim // 6) // (depth * height * width)

        Raises:
            AssertionError: if fourier_dim is not a multiple of 6, or the encoded
                feature count is not a multiple of depth * height * width.
        """
        emb = self.embeddings[idx]  # [b, embedding_dim]

        encoded = self._apply_fourier_encoding(emb)  # [b, embedding_dim * 2 * (fourier_dim // 6)]

        # Fold the flat feature vector into channels x volume.
        total_voxels = depth * height * width
        channels = encoded.shape[-1] // total_voxels
        assert channels * total_voxels == encoded.shape[-1], "Dimension mismatch: can't reshape to requested dimensions"

        return encoded.view(-1, channels, depth, height, width)


class TrainableFourierEmbedding3D(nn.Module):
    """Project learnable per-index vectors onto a fixed 3D sin/cos positional grid.

    Each index selects ``3 * embedding_dim`` learnable values, viewed as 3 vectors
    of length ``embedding_dim``. Each vector is dotted with the positional
    embedding of every voxel, producing a 3-channel volume per index.
    """

    def __init__(self, num_embeddings=10, embedding_dim=512, image_size=(8, 16, 16), temperature=10000):
        """
        Initialize trainable embeddings with Fourier encoding for 3D data

        Args:
            num_embeddings (int): Number of trainable embeddings (default: 10)
            embedding_dim (int): Length of each of the 3 per-index vectors and of
                the positional encoding. Must be a multiple of 6 (required by
                posemb_sincos_3d). NOTE(review): the default 512 is not.
            image_size (int or sequence of 3 ints): (depth, height, width) of the
                output volume; a single int means a cube.
            temperature (int): Temperature for Fourier encoding
        """
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size, image_size)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.depth = image_size[0]
        self.height = image_size[1]
        self.width = image_size[2]

        # 3 * embedding_dim values per index; forward() splits them into 3 channels.
        self.embeddings = nn.Embedding(num_embeddings, 3 * embedding_dim)

        # Fixed (depth*height*width, embedding_dim) table, registered as a buffer so
        # it follows .to()/.cuda() and is stored in the state_dict.
        self.register_buffer(
            'pos_emb',
            posemb_sincos_3d(self.depth, self.height, self.width, embedding_dim, temperature=temperature),
        )

    def forward(self, idx):
        """
        Forward pass for 3D data using simple dot product

        Args:
            idx (torch.Tensor): Indices of embeddings to use, shape (b,) or
                (b, num_directions)

        Returns:
            torch.Tensor: Features of shape (b, 3, d, h, w) for 1-D ``idx``, or
            (b, num_directions, 3, d, h, w) for 2-D ``idx``.
        """
        # A 1-D index gets a singleton direction axis, removed again at the end.
        squeeze = idx.dim() == 1
        if squeeze:
            idx = idx.unsqueeze(-1)
        batch_size, num_directions = idx.shape

        batch_emb = self.embeddings(idx).reshape(
            batch_size, num_directions, 3, self.embedding_dim
        )  # [b, num_directions, 3, embedding_dim]

        # [b, n, 3, E] @ [E, d*h*w] -> [b, n, 3, d*h*w]. matmul broadcasts the shared
        # positional table, so no per-batch expanded copy is needed.
        scores = torch.matmul(batch_emb, self.pos_emb.transpose(0, 1))

        output = scores.reshape(batch_size, num_directions, 3, self.depth, self.height, self.width)
        return output.squeeze(1) if squeeze else output
    