
import einops
import einx
import torch
import torch.nn as nn


class FoundationEmbeddingX(torch.nn.Module):
    """Linear projection of feature vectors into the model dimension.

    One `nn.Linear(n_features, dim)` layer is shared by the support and the
    query inputs, so both are mapped with the same weights.
    """

    def __init__(self, dim: int, n_features: int) -> None:
        """Create the shared projection.

        Args:
            dim: size of the output (embedding) axis.
            n_features: size of the last axis of the inputs.
        """
        super().__init__()

        self.dim, self.n_features = dim, n_features
        self.x_embedding = nn.Linear(in_features=n_features, out_features=dim)

    def forward(
        self,
        x_support: torch.Tensor,
        x_query__: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project both tensors along their last axis.

        Args:
            x_support: support features, last axis of size `n_features`.
            x_query__: query features, last axis of size `n_features`.

        Returns:
            `(x_support, x_query__)` with the last axis replaced by `dim`.
        """
        project = self.x_embedding
        return project(x_support), project(x_query__)
    

class FoundationQuantileEmbeddingX(torch.nn.Module):
    """Parameter-free quantile (rank) normalisation of features.

    Every feature column of the support set defines 999 empirical quantiles
    (levels 0.001 ... 0.999). Support and query values are replaced by the
    index of the quantile bucket they fall into, and the result is then
    standardised with the mean and variance of the non-padded support rows.

    Axis legend used in the comments below:
        b = batch size
        s = number of observations
        f = number of features
        q = number of quantiles
    """

    def __init__(
        self,
        dim: int,
        n_features: int,
        feature_count_scaling: bool,
    ) -> None:
        """Store the configuration; the module has no learnable parameters.

        Args:
            dim: stored on the module; not used by `forward`.
            n_features: nominal feature count used by the optional rescaling.
            feature_count_scaling: if true, multiply the outputs by
                `n_features / (number of unmasked features)` per batch item.
        """
        super().__init__()

        self.dim = dim
        self.n_features = n_features
        self.feature_count_scaling = feature_count_scaling

    @staticmethod
    def _bucket_index(x: torch.Tensor, quantiles: torch.Tensor) -> torch.Tensor:
        """Map values `x` (b, s, f) to float bucket indices (b, s, f).

        `quantiles` is (b, f, q) and ascending along q. `searchsorted` with
        its default `right=False` uses the same boundary rule as
        `torch.bucketize` and is batched over the leading (b, f) axes.
        """
        per_feature = x.transpose(1, 2).contiguous()  # (b, f, s)
        index = torch.searchsorted(quantiles, per_feature)  # (b, f, s)
        return index.transpose(1, 2).float().contiguous()  # (b, s, f)

    def forward(
        self,
        x_support: torch.Tensor,
        x_query__: torch.Tensor,
        padding_mask: torch.Tensor,
        feature_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantile-encode and standardise the support and query features.

        Args:
            x_support: (b, s, f) support features. Not modified.
            x_query__: (b, s_query, f) query features.
            padding_mask: (b, s) boolean mask, true for padded support rows.
            feature_mask: (b, f) boolean mask; the count of false entries per
                batch item is used as the feature count for the rescaling.

        Returns:
            `(x_support, x_query__)` as float tensors of the input shapes.
            Padded support rows and zero-variance feature columns are 0.
        """
        seq_len = (~padding_mask).sum(dim=1)  # (b,)
        feature_count = (~feature_mask).sum(dim=1)  # (b,)
        obs_scale = seq_len[:, None, None]

        # Work on a copy: writing the sentinel into the argument would
        # silently corrupt the caller's tensor.
        x_support = x_support.clone()

        # Padded rows are set to 9999 so that they sort to the top of every
        # column. NOTE(review): this assumes real feature values stay below
        # 9999; the quantiles are still computed over all s rows.
        x_support[padding_mask] = 9999

        q = torch.arange(1, 1000, dtype=torch.float, device=x_support.device) / 1000
        quantiles = torch.quantile(x_support, q=q, dim=1)  # (q, b, f)
        quantiles = quantiles.permute(1, 2, 0).contiguous()  # (b, f, q)

        x_support = self._bucket_index(x_support, quantiles) / obs_scale
        x_query__ = self._bucket_index(x_query__, quantiles) / obs_scale

        # Make sure that the padding is not used in the calculation of the mean
        x_support[padding_mask] = 0
        x_support_mean = x_support.sum(dim=1, keepdim=True) / obs_scale

        x_support = x_support - x_support_mean
        x_query__ = x_query__ - x_support_mean

        # Make sure that the padding is not used in the calculation of the variance
        x_support[padding_mask] = 0
        x_support_var = (x_support**2).sum(dim=1, keepdim=True) / obs_scale
        x_support_std = x_support_var.sqrt()

        x_support = x_support / x_support_std
        x_query__ = x_query__ / x_support_std

        # A support column with a single unique value has zero variance; the
        # division above yields nan/inf there, so set the feature to zero.
        constant_column = x_support_var == 0  # (b, 1, f), broadcast over s
        x_support = x_support.masked_fill(constant_column, 0.0)
        x_query__ = x_query__.masked_fill(constant_column, 0.0)

        if self.feature_count_scaling:
            feature_scale = feature_count[:, None, None]
            x_support = x_support * self.n_features / feature_scale
            x_query__ = x_query__ * self.n_features / feature_scale

        return x_support, x_query__


class FoundationEmbeddingYFloat(torch.nn.Module):
    """Embed scalar float targets with a `Linear(1, dim)` layer.

    Query targets are unknown, so they are represented by all-zero vectors.
    """

    def __init__(self, dim: int) -> None:
        """Create the scalar-to-vector projection of width `dim`."""
        super().__init__()

        self.dim = dim
        self.y_embedding = nn.Linear(in_features=1, out_features=dim)

    def forward(self, y_support: torch.Tensor, n_obs_query: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed support targets and build zero placeholders for the query.

        Args:
            y_support: (b, n) support targets; cast to float32 before use.
            n_obs_query: number of query observations per batch item.

        Returns:
            `(y_support, y_query)` of shapes (b, n, dim) and
            (b, n_obs_query, dim); `y_query` is all zeros.
        """
        n_batch = y_support.shape[0]

        # Add a trailing singleton axis so each scalar becomes a 1-vector.
        targets = einops.rearrange(y_support.type(torch.float32), 'b n -> b n 1')
        embedded = self.y_embedding(targets)

        placeholder = torch.zeros(
            (n_batch, n_obs_query, self.dim),
            dtype=torch.float32,
            device=embedded.device,
        )
        return embedded, placeholder
    


class FoundationEmbeddingYInteger(torch.nn.Module):
    """Embed integer class targets with a learned lookup table.

    Support labels equal to `PAD_LABEL` are treated as padding and embedded
    as zero vectors. Query targets are unknown and all receive the single
    learned "mask" vector.
    """

    # Label value that marks a padded support position.
    PAD_LABEL = -100

    def __init__(
        self,
        n_classes: int,
        dim: int,
    ) -> None:
        """Create the class table and the one-row mask table.

        Args:
            n_classes: number of distinct class labels (valid ids are
                0 .. n_classes - 1).
            dim: width of every embedding vector.
        """
        super().__init__()

        self.n_classes = n_classes
        self.dim = dim

        self.y_embedding = nn.Embedding(n_classes, dim)
        self.y_mask = nn.Embedding(1, dim)  # masking is also modeled as a separate class

    def forward(self, y_support: torch.Tensor, n_obs_query: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed support labels and build mask embeddings for the query.

        Args:
            y_support: (b, n) integer labels; `PAD_LABEL` marks padding.
            n_obs_query: number of query observations per batch item.

        Returns:
            `(y_sup, y_query)` of shapes (b, n, dim) and
            (b, n_obs_query, dim), in the dtype of the embedding weights.
        """
        batch_size = y_support.shape[0]

        is_pad = y_support == self.PAD_LABEL

        # Look up every position (pads temporarily mapped to class 0, which
        # always exists) and zero the padded positions afterwards. Unlike
        # filling a preallocated float32 buffer through a boolean index, this
        # follows the weight dtype (works after .double()/.half()) and the
        # zeroed positions pass no gradient to class 0.
        lookup_ids = y_support.masked_fill(is_pad, 0)
        y_sup = self.y_embedding(lookup_ids).masked_fill(is_pad.unsqueeze(-1), 0.0)

        # Every query position indexes row 0 of the mask table.
        mask_ids = torch.zeros((batch_size, n_obs_query), device=y_support.device, dtype=torch.int64)
        y_query = self.y_mask(mask_ids)

        return y_sup, y_query
    

class FoundationObservationEmbedding(torch.nn.Module):
    """Assign each observation a randomly drawn learned embedding.

    The module owns a pool of `2**16` embedding vectors. On every call, each
    batch item draws `n_obs` distinct pool rows uniformly at random.
    """

    def __init__(self, dim: int) -> None:
        """Create the embedding pool with vectors of width `dim`."""
        super().__init__()

        self.dim = dim
        self.max_dim = 2**16
        self.embedding = nn.Embedding(num_embeddings=self.max_dim, embedding_dim=dim)

    def forward(self, batch_size: int, n_obs: int) -> torch.Tensor:
        """Sample `n_obs` distinct embeddings for each of `batch_size` items.

        Args:
            batch_size: number of independent draws (rows of the result).
            n_obs: embeddings per draw; must not exceed the pool size.

        Returns:
            Tensor of shape (batch_size, n_obs, dim).
        """
        pool_size = self.max_dim
        assert n_obs <= pool_size, f'Number of observations is too large. Max is {self.max_dim}, got {n_obs}'

        # Equal weights + no replacement: every row is a uniform draw of
        # n_obs distinct pool indices.
        uniform = torch.ones(
            batch_size,
            pool_size,
            dtype=torch.float32,
            device=self.embedding.weight.device,
        )
        chosen = torch.multinomial(uniform, n_obs, replacement=False)

        return self.embedding(chosen)