import torch
import torch.nn as nn
import torch.nn.functional as F

class Pooling(nn.Module):
    """
    A self-contained pooling module that supports multiple pooling methods.

    Supported pooling methods:
        - 'mean'
        - 'mean_normalized'
        - 'sum'
        - 'sum_normalized'
        - 'last'
        - 'last_normalized'
        - 'max'
        - 'max_normalized'
        - 'attention'
        - 'attention_normalized'
        - 'weighted_avg' (alias: 'weighted_average')

    Every ``*_normalized`` variant L2-normalizes the pooled output along its
    last dimension.

    Args:
        pooling_method (str): The pooling method to use.
        d_model (int, optional): Feature dimension. Required if pooling_method is attention-based.
        num_queries (int, optional): Number of queries for attention pooling. Defaults to 1.
        agg_dim (int): Dimension to aggregate over for the mean/sum/last/max
            methods. Defaults to -2. It is not forwarded to the attention or
            weighted-average submodules.
        num_patches (int, optional): Number of token positions. Required if
            pooling_method is the weighted average.

    Raises:
        ValueError: If a required argument for the selected method is missing.
    """

    _ATTENTION_METHODS = ('attention', 'attention_normalized')
    # Both spellings are accepted: the documented name is 'weighted_avg',
    # while earlier code only recognised 'weighted_average'.
    _WEIGHTED_AVG_METHODS = ('weighted_avg', 'weighted_average')
    _NORMALIZED_SUFFIX = '_normalized'

    def __init__(
        self,
        pooling_method: str = 'mean',
        d_model: int = None,
        num_queries: int = 1,
        agg_dim: int = -2,
        num_patches: int = None,
    ):
        super().__init__()
        self.pooling_method = pooling_method
        self.agg_dim = agg_dim

        # Only the learnable pooling methods need a submodule.
        if self.pooling_method in self._ATTENTION_METHODS:
            if d_model is None:
                raise ValueError("d_model must be provided for attention pooling.")
            self.attention_pooling = AttentionPooling(d_model, num_queries=num_queries)
        elif self.pooling_method in self._WEIGHTED_AVG_METHODS:
            if num_patches is None:
                raise ValueError("num_patches must be provided for weighted average pooling.")
            self.weighted_avg_pooling = WeightedAveragePooling(n_tokens=num_patches)

    def forward(self, x: torch.Tensor, patch_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Apply pooling to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, patch_len, d_model].
            patch_mask (torch.Tensor, optional): Mask indicating valid patches with
                shape [batch_size, patch_len]. It is forwarded to the attention
                pooling submodule as ``attention_mask``. The mean/sum/last/max and
                weighted-average methods currently ignore it.

        Returns:
            torch.Tensor: Pooled output; of shape [batch_size, d_model] for the
            input shape above with the default ``agg_dim``.

        Raises:
            ValueError: If ``pooling_method`` is not one of the supported names.
        """
        # TODO: make the non-attention methods honour patch_mask as well.
        method = self.pooling_method

        # Attention-based pooling is delegated to the attention submodule.
        if method in self._ATTENTION_METHODS:
            pooled = self.attention_pooling(x, attention_mask=patch_mask)
            if method == 'attention_normalized':
                # Normalize over the feature axis (last dim). dim=1 would be the
                # variable axis for (B, V, D) outputs or the query axis for
                # (B, Q, D) outputs.
                pooled = F.normalize(pooled, dim=-1, p=2, eps=1e-12)
            return pooled

        if method in self._WEIGHTED_AVG_METHODS:
            return self.weighted_avg_pooling(x)

        # Split "<reduction>_normalized" into the reduction and a normalize flag.
        normalize = isinstance(method, str) and method.endswith(self._NORMALIZED_SUFFIX)
        reduction = method[:-len(self._NORMALIZED_SUFFIX)] if normalize else method

        if reduction == 'mean':
            pooled = torch.mean(x, dim=self.agg_dim)
        elif reduction == 'sum':
            pooled = torch.sum(x, dim=self.agg_dim)
        elif reduction == 'last':
            pooled = x.select(dim=self.agg_dim, index=-1)
        elif reduction == 'max':
            pooled, _ = torch.max(x, dim=self.agg_dim)
        else:
            raise ValueError(f"Pooling method '{self.pooling_method}' not implemented.")

        if normalize:
            pooled = F.normalize(pooled, dim=-1, p=2, eps=1e-12)
        return pooled


class AttentionPooling(torch.nn.Module):
    """Attention pooling for both [batch_size, seq_len, d_model]
    and [batch_size, num_variables, seq_len, d_model] inputs.

    A set of ``num_queries`` learnable query vectors (initialised to zero)
    attends over the token axis of the input using scaled dot-product
    attention with bias-free linear projections for queries, keys and values.

    Args:
        d_model: Feature dimension of the input tokens.
        num_queries: Number of learnable query vectors. Defaults to 1.
    """
    def __init__(self, d_model, num_queries=1):
        super().__init__()

        self.w_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.w_k = torch.nn.Linear(d_model, d_model, bias=False)
        self.w_v = torch.nn.Linear(d_model, d_model, bias=False)

        self.q = torch.nn.Parameter(torch.zeros(num_queries, d_model))

        self.softmax = torch.nn.Softmax(dim=-1)
        self.scale = 1 / (d_model ** 0.5)

    def forward(self, x, attention_mask=None, return_scores=False):
        """
        x: Tensor of shape (B, N, D) or (B, V, N, D)
        attention_mask: Tensor of shape (B, N) or (B, V, N); truthy entries mark
            tokens that may be attended to.
        return_scores: If True, also return the scaled (and masked) attention
            scores of shape (B, Q, N) or (B, V, Q, N).
        Returns:
            Pooled output of shape (B, D) or (B, V, D) when there is a single
            query; with Q > 1 queries the query axis is kept, giving
            (B, Q, D) or (B, V, Q, D). Rows whose mask has no valid token
            yield a zero vector instead of NaN.
        Raises:
            ValueError: If x is neither 3- nor 4-dimensional.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Unsupported input shape: {x.shape}")

        # The projected queries are shared by every batch element, so they are
        # left as [Q, D] and broadcast by matmul over all leading dims of x.
        q = self.w_q(self.q)                                            # [Q, D]
        k = self.w_k(x)                                                 # [..., N, D]
        v = self.w_v(x)                                                 # [..., N, D]
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale      # [..., Q, N]

        if attention_mask is None:
            attn_weights = self.softmax(scores)                         # [..., Q, N]
        else:
            keep = attention_mask.to(torch.bool)
            if x.dim() == 4:
                # Accept any mask layout holding B*V*N elements, e.g. (B, V, N).
                keep = keep.reshape(x.shape[:-1])
            keep = keep.unsqueeze(-2)                                   # [..., 1, N]
            scores = scores.masked_fill(~keep, float('-inf'))

            # A row with no valid token would be softmax over all -inf, i.e.
            # NaN in both the output and the gradients. Feed such rows finite
            # scores and zero their weights afterwards.
            has_valid = keep.any(dim=-1, keepdim=True)                  # [..., 1, 1]
            attn_weights = self.softmax(scores.masked_fill(~has_valid, 0.0))
            attn_weights = attn_weights.masked_fill(~has_valid, 0.0)

        # squeeze(-2) drops the query axis only when there is a single query.
        pooled = torch.matmul(attn_weights, v).squeeze(-2)
        return (pooled, scores) if return_scores else pooled


class WeightedAveragePooling(nn.Module):
    """
    Learnable weighted average over token positions.

    One scalar logit is stored per token position; the logits are turned into
    averaging weights with a softmax, and the same weights are applied to every
    batch element (and every variable for multivariate input).
    Accepts univariate (B, N, D) and multivariate (B, V, N, D) tensors.
    """
    def __init__(self, n_tokens, init_strategy=None, dtype=torch.float32, **init_kwargs):
        """
        Args:
            n_tokens: number of token positions to weight
            init_strategy: callable invoked as
                ``init_strategy(n_tokens, dtype=dtype, **init_kwargs)`` whose
                result becomes the initial parameter; when None,
                WeightInitStrategy.uniform is used
            dtype: dtype forwarded to init_strategy
            init_kwargs: extra keyword arguments forwarded to init_strategy
                (e.g., seed, index, base)
        """
        super().__init__()
        make_weights = WeightInitStrategy.uniform if init_strategy is None else init_strategy
        self.w = nn.Parameter(
            make_weights(n_tokens, dtype=dtype, **init_kwargs),
            requires_grad=True,
        )

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, N, D) or (B, V, N, D)
        Returns:
            Tensor of shape (B, D) or (B, V, D)
        Raises:
            ValueError: if x is neither 3- nor 4-dimensional
        """
        rank = x.dim()
        if rank not in (3, 4):
            raise ValueError("Input tensor must have shape (B, N, D) or (B, V, N, D)")

        # The token axis is always second to last: 1 for (B, N, D), 2 for (B, V, N, D).
        token_axis = rank - 2
        leading_shape = x.shape[:token_axis]                       # (B,) or (B, V)

        logits = self.w.expand(*leading_shape, -1)                 # (B, N) or (B, V, N)
        token_weights = torch.softmax(logits, dim=token_axis)
        weighted_tokens = x * token_weights.unsqueeze(-1)
        return weighted_tokens.sum(dim=token_axis)                 # (B, D) or (B, V, D)


class WeightInitStrategy:
    """
    Namespace of factories producing a 1-D tensor of ``n_tokens`` initial weights.

    Every factory accepts ``n_tokens``, ``dtype`` and ``seed`` so they can be
    used interchangeably; ``seed`` is only consumed by :meth:`random`.
    """

    @staticmethod
    def uniform(n_tokens, dtype=torch.float32, seed=None):
        """Uniform initialization (mean pooling): every weight is 1 / n_tokens."""
        return torch.ones(n_tokens, dtype=dtype) / n_tokens

    @staticmethod
    def bias_to_first(n_tokens, dtype=torch.float32, seed=None):
        """Linearly decaying weights (1.0 down to 0.1, scaled to sum to 1) favoring early tokens"""
        weights = torch.linspace(1.0, 0.1, steps=n_tokens, dtype=dtype)
        return weights / weights.sum()

    @staticmethod
    def bias_to_last(n_tokens, dtype=torch.float32, seed=None):
        """Linearly increasing weights (0.1 up to 1.0, scaled to sum to 1) favoring later tokens"""
        weights = torch.linspace(0.1, 1.0, steps=n_tokens, dtype=dtype)
        return weights / weights.sum()

    @staticmethod
    def random(n_tokens, dtype=torch.float32, seed=None):
        """Random weights, scaled to sum to 1, with RNG isolation.

        Samples are drawn from a private ``torch.Generator`` rather than the
        global RNG; the generator is seeded only when ``seed`` is given.
        """
        g = torch.Generator()
        if seed is not None:
            g.manual_seed(seed)
        weights = torch.rand(n_tokens, dtype=dtype, generator=g)
        return weights / weights.sum()

    @staticmethod
    def single_token_focus(n_tokens, index=0, dtype=torch.float32, seed=None):
        """1-hot vector focusing on a specific token index"""
        weights = torch.zeros(n_tokens, dtype=dtype)
        weights[index] = 1.0
        return weights

    @staticmethod
    def sum_pooling(n_tokens, dtype=torch.float32, seed=None):
        """Unnormalized equal weights (all ones, as in sum pooling)"""
        return torch.ones(n_tokens, dtype=dtype)

    @staticmethod
    def exp_decay(n_tokens, base=0.9, dtype=torch.float32, seed=None):
        """Exponential decay — favors early tokens.

        Token ``i`` gets weight ``base**i``, scaled so the weights sum to 1.
        """
        weights = torch.tensor([base**i for i in range(n_tokens)], dtype=dtype)
        return weights / weights.sum()

    @staticmethod
    def exp_growth(n_tokens, base=1.1, dtype=torch.float32, seed=None):
        """Exponential growth — favors late tokens.

        Token ``i`` gets weight ``base**i``, scaled so the weights sum to 1.
        """
        # Exponents must ascend with the token index: iterating them in reverse
        # with base > 1 would put the largest weight on the FIRST token.
        weights = torch.tensor([base**i for i in range(n_tokens)], dtype=dtype)
        return weights / weights.sum()