import torch
import torch.nn as nn
import torch.nn.functional as F


class PoolingModule(torch.nn.Module):
    """Pool a sequence of token embeddings ``(B, N, D)`` into one vector per sequence ``(B, D)``.

    Supported strategies: ``'sum'``, ``'avg'``, ``'max'``, ``'last_token'``, ``'cls'``,
    ``'attention_pool'`` (learned, via ``AttentionPooling``) and ``'weighted_avg'``
    (learned, via ``WeightedAveragePooling``).

    For BERT-style backbones (see ``_BERT_BACKBONES``) the first position is treated as the
    CLS token and is dropped before pooling, unless the strategy is ``'cls'``.
    """

    # Backbones whose first sequence position is a CLS token. Shared by __init__ and forward
    # so that the CLS handling cannot drift apart between the two.
    _BERT_BACKBONES = ('bert-base', 'bert-finetuned')

    def __init__(self, pooling, cfg, hidden_dim=None, device=None, dtype=None):
        """
        Parameters
        ----------
        pooling : str
            Name of the pooling strategy (see class docstring).
        cfg : object
            Configuration object. This class reads ``cfg.backbone`` and ``cfg.padding_side``,
            and for ``'weighted_avg'`` also ``cfg.datasets.max_length`` and ``cfg.datasets.name``.
        hidden_dim : int or None
            Feature dimension, forwarded as ``d_model`` to ``AttentionPooling``.
        device : unused
            Accepted for interface compatibility; not used by this class.
        dtype : torch.dtype or None
            Parameter dtype for ``AttentionPooling``; falls back to ``torch.bfloat16`` when None.
        """
        super().__init__()
        self.cfg = cfg
        accepted_pooling = ['sum', 'avg', 'last_token', 'attention_pool', 'cls', 'weighted_avg', "max"]

        assert pooling in accepted_pooling, f"Pooling must be one of {accepted_pooling}, got {pooling}"
        # CLS pooling only makes sense for backbones that emit a CLS token.
        if pooling == 'cls':
            assert cfg.backbone in self._BERT_BACKBONES, "Only bert-base backbone can have cls pooling"

        self.pooling = pooling
        if pooling == 'attention_pool':
            self.attention_pooling = AttentionPooling(
                d_model=hidden_dim,
                dtype=dtype if dtype is not None else torch.bfloat16,
            )
        else:
            self.attention_pooling = None

        if pooling == "weighted_avg":
            wa_n_tokens = cfg.datasets.max_length
            if cfg.backbone in self._BERT_BACKBONES:
                # For BERT, the cls token is removed (see forward).
                wa_n_tokens = cfg.datasets.max_length - 1
            elif cfg.datasets.name == "next_token_tinystories":
                # For next token prediction, one token is removed (and shifted) to create the
                # labels for the next token prediction task.
                wa_n_tokens = cfg.datasets.max_length - 1

            self.weighted_average_pooling = WeightedAveragePooling(n_tokens=wa_n_tokens, cfg=self.cfg)
        else:
            self.weighted_average_pooling = None

    def forward(self, x, attention_mask=None):
        """Pool ``x`` of shape (B, N, D) into (B, D) using the configured strategy.

        ``attention_mask`` has shape (B, N) with 1 for valid tokens and 0 for padding.
        Despite the ``None`` default it is required; an AssertionError is raised otherwise.
        """
        # Unpacking doubles as a rank check: x must have exactly three dimensions.
        B, N, D = x.shape
        assert attention_mask is not None, "Attention mask must be provided for pooling"

        # For BERT-style backbones, CLS is removed from the sequence unless we pool on CLS.
        # The same backbone set is used in __init__ to size the weighted-average weights.
        if (self.pooling != 'cls') and (self.cfg.backbone in self._BERT_BACKBONES):
            x = x[:, 1:, :]
            attention_mask = attention_mask[:, 1:]

        if self.pooling == 'sum':
            x = self.sum_pooling(x, attention_mask)
        elif self.pooling == 'avg':
            x = self.avg_pooling(x, attention_mask)
        elif self.pooling == 'last_token':
            x = self.last_token_pooling(x, attention_mask, self.cfg.padding_side)
        elif self.pooling == 'attention_pool':
            x = self.attention_pooling(x, attention_mask)
        elif self.pooling == 'cls':
            x = self.cls_pooling(x, attention_mask)
        elif self.pooling == 'weighted_avg':
            x = self.weighted_average_pooling(x, attention_mask)
        elif self.pooling == 'max':
            x = self.max_pooling(x, attention_mask)
        else:
            raise ValueError(f"Pooling method {self.pooling} not recognized.")

        return x

    # Static functions for pooling methods. Each must take x, a tensor of shape (B, N, D) and attention_mask, a tensor of shape (B, N)
    @staticmethod
    def sum_pooling(x, attention_mask):
        """Sum the embeddings of valid tokens; padded positions contribute zero."""
        x = x * attention_mask.unsqueeze(-1)
        x = x.sum(dim=1)
        return x

    @staticmethod
    def avg_pooling(x, attention_mask):
        """Average the embeddings of valid tokens.

        A sequence with no valid tokens yields a zero vector instead of NaN (0 / 0).
        """
        x = x * attention_mask.unsqueeze(-1)
        x = x.sum(dim=1)
        counts = attention_mask.sum(dim=1).unsqueeze(-1)
        # Only an exactly-zero count is replaced; the numerator is zero there as well.
        counts = counts.masked_fill(counts == 0, 1)
        x = x / counts
        return x

    @staticmethod
    def max_pooling(x, attention_mask):
        """Element-wise maximum over valid tokens; padded positions are set to -inf first."""
        mask = attention_mask.unsqueeze(-1).bool()
        x = x.masked_fill(~mask, float('-inf'))
        x, _ = x.max(dim=1)
        return x

    @staticmethod
    def get_last_token_idx(attention_mask, padding_side):
        """Return the position of the last valid token.

        For right padding this is a (B, 1) tensor holding ``num_valid_tokens - 1``;
        for left padding it is the plain integer ``-1``.

        Raises
        ------
        ValueError
            If ``padding_side`` is neither ``'right'`` nor ``'left'``.
        """
        if padding_side == "right":
            return (attention_mask.sum(dim=1) - 1).unsqueeze(-1)
        elif padding_side == "left":
            return -1
        else:
            raise ValueError(f"Invalid padding side: {padding_side}. Must be 'right' or 'left'.")

    @staticmethod
    def last_token_pooling(x, attention_mask, padding_side):
        """Select the embedding of the last valid token of every sequence.

        Raises
        ------
        AssertionError
            For left padding, if the final position of any sequence is not valid.
        ValueError
            If ``padding_side`` is neither ``'right'`` nor ``'left'``.
        """
        if padding_side == "right":
            # gather requires an int64 index; the cast also covers floating-point masks.
            idx = PoolingModule.get_last_token_idx(attention_mask, padding_side).long()  # (B, 1)
            idx = idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])  # (B, 1, D)
            return x.gather(1, idx).squeeze(1)

        elif padding_side == "left":
            assert bool((attention_mask[:, -1] == 1).all()), "Last token must be valid for last token pooling"
            return x[:, -1, :]
        else:
            raise ValueError(f"Invalid padding side: {padding_side}. Must be 'right' or 'left'.")

    def cls_pooling(self, x, attention_mask):
        """Return the embedding at position 0; ``attention_mask`` is not used."""
        x = x[:, 0, :]
        return x

class AttentionPooling(torch.nn.Module):
    """Pool a token sequence with single-head attention driven by learnable queries.

    Keys and values are linear projections of the input tokens; the queries are learnable
    parameters (also projected). With ``num_queries == 1`` the output is one vector per
    sequence.
    """

    def __init__(self, d_model, num_queries=1, dtype=None):
        """
        Parameters
        ----------
        d_model : int
            Feature dimension of the input tokens and of the pooled output.
        num_queries : int
            Number of learnable query vectors.
        dtype : torch.dtype or None
            dtype used for the projection weights and the queries.
        """
        super().__init__()

        # Projection matrices
        self.w_q = torch.nn.Linear(d_model, d_model, bias=False, dtype=dtype)
        self.w_k = torch.nn.Linear(d_model, d_model, bias=False, dtype=dtype)
        self.w_v = torch.nn.Linear(d_model, d_model, bias=False, dtype=dtype)

        # Learnable queries
        #NOTE: Initialize with zeros: start from uniform attention (i.e. average of all tokens)
        self.q = torch.nn.Parameter(torch.zeros(num_queries, d_model, dtype=dtype))

        self.softmax = torch.nn.Softmax(dim=-1)
        self.scale = 1 / (d_model ** 0.5)

    def forward(self, x, attention_mask=None):
        """
        x: Tensor of shape (B, N, D)  [Batch, Sequence Length, Feature Dim]
        attention_mask: Tensor of shape (B, N)  where 1 indicates valid tokens
                        and 0 indicates padded tokens. If None, we assume all valid.

        Returns: Tensor of shape (B, D) [Batch, Feature Dim] when num_queries == 1,
                 otherwise (B, num_queries, D). Sequences without any valid token
                 produce a zero vector rather than NaN.
        """

        B, N, D = x.shape

        # Apply projections. expand() broadcasts the queries over the batch without copying.
        q_i = self.w_q(self.q).unsqueeze(0).expand(B, -1, -1)  # (B, num_queries, D)
        k = self.w_k(x)  # (B, N, D)
        v = self.w_v(x)  # (B, N, D)

        # Compute attention scores
        scores = torch.bmm(q_i, k.transpose(1, 2)) * self.scale  # (B, num_queries, N)
        if attention_mask is not None:
            padded = ~attention_mask.unsqueeze(1).to(torch.bool)  # (B, 1, N)
            scores = scores.masked_fill(padded, float('-inf'))
            weights = self.softmax(scores)
            # A row whose tokens are all padded is a softmax over only -inf, i.e. NaN.
            # Give such rows zero weight so the pooled vector is 0 instead of NaN.
            weights = weights.masked_fill(padded.all(dim=-1, keepdim=True), 0.0)
        else:
            weights = self.softmax(scores)

        # Compute the weighted sum
        x = torch.bmm(weights, v).squeeze(1)

        return x


class WeightedAveragePooling(nn.Module):
    """Pool tokens with one learnable scalar weight per sequence position.

    The per-position weights are turned into a distribution with a softmax that is
    restricted to the valid (unmasked) positions of every sequence.
    """

    def __init__(self, n_tokens, cfg):
        """Create ``n_tokens`` position weights, all initialised to one; ``cfg`` is stored as-is."""
        super().__init__()
        self.w = nn.Parameter(torch.ones(n_tokens))
        self.cfg = cfg

    def forward(self, x, attention_mask=None):
        """
        Weighted sum of the tokens using the learnt, mask-aware normalised weights.

        x: Tensor of shape (B, N, D)  [Batch, Sequence Length, Feature Dim]
        attention_mask: Tensor of shape (B, N), 1 for valid tokens and 0 for padded
                        tokens. If None, every position is treated as valid.

        Returns: Tensor of shape (B, D) [Batch, Feature Dim]
        """
        assert self.cfg.padding_side == "left", "Weighted average pooling only supports left padding"

        logits = self.w[None, :, None]  # (1, N, 1)
        if attention_mask is not None:
            is_padding = attention_mask[..., None].to(bool).logical_not()  # (B, N, 1)
            # Padded positions get -inf so they receive zero probability mass.
            logits = logits.masked_fill(is_padding, float("-inf"))  # (B, N, 1)

        probs = F.softmax(logits, dim=1)
        # A sequence with no valid position is a softmax over only -inf, which is NaN;
        # those entries are replaced by 0.
        probs = probs.masked_fill(probs.isnan(), 0.0)

        pooled = (x * probs.to(x.dtype)).sum(dim=1)
        return pooled

        
class CausalAggregator(torch.nn.Module):
    """
    Aggregates the embeddings across time in a causal manner.
    i.e. at each time step t, the aggregated embedding is some function
    of all embeddings up to that point.

    Currently supports:
      - 'sum' => cumsum over time
      - 'avg' => cumsum over time, then divide by (# of valid tokens)
      - 'max' => running (cumulative) max over the valid tokens of x[:, :t+1]
      - 'attention_pool' => at each t, run attention on x[:, :t+1]
                            to get a single embedding for that position
      - 'weighted_avg' => at each t, run the weighted-average pooler on x with
                          every position after t masked out
    """
    def __init__(self, agg_method, attn_pool=None, weighted_avg_pool=None):
        """
        Parameters
        ----------
        agg_method : str
            One of ['sum', 'avg', 'attention_pool', 'weighted_avg', 'max'].
        attn_pool : callable or None
            If `agg_method='attention_pool'`, this should be an instance of
            AttentionPooling or a compatible module that does attention
            pooling over dimension=1.
        weighted_avg_pool : callable or None
            If `agg_method='weighted_avg'`, a callable taking `(x, mask)` with
            x of shape (B, N, D) and mask of shape (B, N), returning (B, D).
        """
        super().__init__()
        self.agg_method = agg_method
        self.attn_pool = attn_pool
        self.weighted_avg_pool = weighted_avg_pool

        # Some basic checks
        valid_methods = ['sum', 'avg', 'attention_pool', 'weighted_avg', 'max']

        assert agg_method in valid_methods, \
            f"CausalAggregator agg_method must be in {valid_methods}"

    def forward(self, x, padding_mask):
        """
        x: Tensor of shape (B, N, D)
        padding_mask: Tensor of shape (B, N) with 1 for valid tokens and 0 for padded.

        Returns a Tensor of shape (B, N, D).
        """
        B, N, D = x.shape

        if self.agg_method == 'sum':
            # Zero out padded positions so they do not contribute, then cumsum over time.
            x_masked = x * padding_mask.unsqueeze(-1)  # (B, N, D)
            out = torch.cumsum(x_masked, dim=1)        # (B, N, D)
            return out

        elif self.agg_method == 'avg':
            # Same cumsum as 'sum', divided by the number of valid tokens seen so far.
            x_masked = x * padding_mask.unsqueeze(-1)  # (B, N, D)
            cum_x = torch.cumsum(x_masked, dim=1)       # (B, N, D)
            cum_counts = torch.cumsum(padding_mask, dim=1)  # (B, N)
            # The epsilon avoids 0 / 0 at positions that precede the first valid token.
            out = cum_x / (cum_counts.unsqueeze(-1) + 1e-8)
            return out

        elif self.agg_method == 'attention_pool':
            # out[:, t, :] = attention pooling over the prefix x[:, :t+1, :].
            out = torch.zeros_like(x)  # shape (B, N, D)
            for t in range(N):
                partial_x = x[:, :t+1, :]            # (B, t+1, D)
                partial_mask = padding_mask[:, :t+1]  # (B, t+1)
                out_t = self.attn_pool(partial_x, partial_mask)
                out[:, t, :] = out_t
            return out

        elif self.agg_method == 'max':
            # Padded positions become -inf so they can never win the maximum. A single
            # cumulative max then yields, at every t, the max over the valid tokens of
            # x[:, :t+1]; positions with no valid token so far stay at -inf.
            valid = padding_mask.unsqueeze(-1).bool()  # (B, N, 1)
            out, _ = torch.cummax(x.masked_fill(~valid, float('-inf')), dim=1)
            return out

        elif self.agg_method == 'last_token':
            # NOTE(review): 'last_token' is not in the constructor's valid_methods, so this
            # branch is only reachable if that check is bypassed. It returns (B, D), unlike
            # the other branches, which return (B, N, D).
            out = x[:, -1, :]
            return out

        elif self.agg_method == 'weighted_avg':
            out = torch.zeros_like(x)
            for t in range(N):
                # Set the future context to 0 in the mask to make
                # weighted average not take these into account.
                mask = padding_mask.clone()
                mask[:, t+1:] = 0
                out_t = self.weighted_avg_pool(x, mask)
                out[:, t, :] = out_t
            return out
        else:
            raise ValueError(f"Invalid aggregation method: {self.agg_method}")