from typing import Iterable
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

class LSTMClassifier(nn.Module):
    """Sequence classifier: an LSTM encoder followed by ReLU and a linear head.

    The final hidden state of the top LSTM layer is used as the
    representation of each sequence.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int = 1) -> None:
        """Build the LSTM encoder and the output layer.

        :param input_size: Number of features per time step.
        :param hidden_size: Size of the LSTM hidden state.
        :param output_size: Number of output logits.
        :param num_layers: Number of stacked LSTM layers.
        """
        super(LSTMClassifier, self).__init__()

        # batch_first=True: un-packed inputs are (batch, seq_len, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

        # A single output unit means binary classification, which needs a
        # different loss function than the multi-class case.
        self.__is_binary = False if output_size > 1 else True

    def forward(self, x) -> torch.Tensor:
        """Classify a batch of sequences.

        :param x: A ``PackedSequence``, a padded tensor of shape
            (batch, seq_len, input_size), or a tuple ``(padded, lengths)``
            that is packed here before running the LSTM.
        :return: Logits of shape (batch, output_size).
        :rtype: torch.Tensor
        """
        # PackedSequence is itself a namedtuple, so it must be excluded
        # before treating a tuple as (padded, lengths).
        if not isinstance(x, PackedSequence) and isinstance(x, tuple):
            padded, lengths = x[0], x[1]
            # Packing inside the model is noted as necessary for DDP
            # (presumably because a PackedSequence cannot be scattered --
            # confirm). pack_padded_sequence requires lengths on the CPU.
            x = pack_padded_sequence(padded, lengths.cpu(), enforce_sorted=False, batch_first=True)

        # h_n has shape (num_layers, batch, hidden_size). The top layer is
        # h_n[-1]; h_n[0] would be the *first* layer and would leave every
        # upper layer out of the prediction when num_layers > 1.
        _, (h_n, _) = self.lstm(x)
        out = torch.relu(h_n[-1])
        return self.linear(out)

class TransformerEncoderClassifier(nn.Module):
    """Sequence classifier built on a stack of Transformer encoder layers.

    The forward pass adds sine/cosine positional encodings, runs the encoder
    with a key-padding mask derived from the per-sequence lengths, pools the
    encoder output over time, and maps it to logits through
    LayerNorm -> GELU -> Linear.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        dim_feedforwards: int = 512,
        num_heads: int = 6,
        num_layers: int = 6,
        max_heads: int = 8,
        pooling_strategy="mean",
    ) -> None:
        """Build the encoder stack and the classification head.

        :param input_size: Feature size of each token (the encoder's ``d_model``).
        :param output_size: Number of output logits.
        :param dim_feedforwards: Hidden size of each encoder layer's feed-forward block.
        :param num_heads: Requested number of attention heads. If it does not
            divide ``input_size``, a compatible value is chosen automatically.
        :param num_layers: Number of stacked encoder layers.
        :param max_heads: Preferred upper bound when choosing a replacement head count.
        :param pooling_strategy: ``"cls"``, ``"mean"`` or ``"max"``; validated in ``forward``.
        """
        super(TransformerEncoderClassifier, self).__init__()

        if input_size % num_heads != 0:
            print("Warning: input_size is not divisible by num_heads.")
            num_heads = self._find_num_heads(input_size, max_heads)
            print("Setting num_heads to {}.".format(num_heads))

        # NOTE(review): nn.TransformerEncoder clones this layer num_layers
        # times, so this registered instance itself appears unused in forward
        # (relevant for DDP's unused-parameter check). Kept so existing
        # state_dict keys still load.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=num_heads,
            dim_feedforward=dim_feedforwards,
            activation='gelu',
            batch_first=True
        )
        self.encoder_stack = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer,
            num_layers=num_layers
        )
        self.linear = nn.Linear(input_size, output_size)
        self.pooling_strategy = pooling_strategy
        self.norm = nn.LayerNorm(input_size)

    @staticmethod
    def _find_num_heads(input_size: int, max_heads: int) -> int:
        """Return an attention head count that divides ``input_size``.

        Uses the smallest even divisor >= 4; if that exceeds ``max_heads``,
        uses 2 instead (valid, since an even divisor implies an even size).
        Sizes without such a divisor are 2 (or 0) or odd: even sizes get 2,
        odd sizes get their largest odd divisor <= ``max_heads`` (at least 1).
        A constant fallback of 2 would be rejected for odd sizes.
        """
        for heads in range(4, input_size + 1, 2):
            if input_size % heads == 0:
                return heads if heads <= max_heads else 2
        if input_size % 2 == 0:
            return 2
        odd_divisors = [h for h in range(1, max_heads + 1, 2) if input_size % h == 0]
        return max(odd_divisors, default=1)

    @staticmethod
    def _positional_encoding(seq_len: int, d_model: int, device) -> torch.Tensor:
        """Return sine/cosine positional encodings of shape (seq_len, d_model)."""
        position = torch.arange(seq_len, device=device).unsqueeze(-1)  # (seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device)
            * (-torch.log(torch.tensor(10000.0, device=device)) / d_model)
        )
        pe = torch.zeros(seq_len, d_model, device=device)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    @staticmethod
    def _padding_mask(lengths: torch.Tensor, seq_len: int, device) -> torch.Tensor:
        """Return a (batch, seq_len) bool mask that is True at padding positions."""
        # Sized by the tokens' seq_len (not max(lengths)) so inputs padded
        # beyond the longest sequence still match the encoder input width.
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        return positions >= lengths.to(device).unsqueeze(1)

    def _pool(self, out: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        """Reduce encoder output (batch, seq_len, d) to (batch, d) per ``self.pooling_strategy``."""
        if self.pooling_strategy == "cls":
            # Output at the first position. NOTE(review): no CLS token is
            # prepended in this class; presumably the caller supplies one.
            return out[:, 0, :]
        if self.pooling_strategy == "mean":
            # Average over non-padding positions; clamp avoids division by zero.
            out = out.masked_fill(padding_mask.unsqueeze(-1), 0.0)
            count_non_padding = (~padding_mask).sum(dim=1).clamp(min=1).unsqueeze(-1)
            return out.sum(dim=1) / count_non_padding
        if self.pooling_strategy == "max":
            # Padding never wins the max; an all-padding row yields -inf, mapped to 0.
            out = out.masked_fill(padding_mask.unsqueeze(-1), float('-inf'))
            out, _ = out.max(dim=1)
            return torch.where(torch.isfinite(out), out, torch.zeros_like(out))
        raise ValueError("Unknown pooling strategy: {}".format(self.pooling_strategy))

    def forward(self, x) -> torch.Tensor:
        """Classify a batch of padded sequences.

        :param x: Tuple ``(tokens, lengths)``: ``tokens`` is a tensor of shape
            (batch, seq_len, input_size) and ``lengths`` holds the number of
            valid positions of each sequence.
        :return: Logits of shape (batch, output_size).
        :raises ValueError: If ``pooling_strategy`` is unknown.
        """
        tokens, lengths = x
        _, seq_len, input_size = tokens.size()
        device = tokens.device

        # (seq_len, input_size) encodings broadcast over the batch dimension.
        tokens = tokens + self._positional_encoding(seq_len, input_size, device)
        padding_mask = self._padding_mask(lengths, seq_len, device)

        out = self.encoder_stack(tokens, src_key_padding_mask=padding_mask)
        out = self._pool(out, padding_mask)

        out = self.norm(out)
        out = nn.functional.gelu(out)
        return self.linear(out)