import torch
import torch.nn as nn
from torch.utils.data import Dataset
from typing import Dict


class SequencePoolingHead(nn.Module):
    """Pool sequence embeddings and run the pooled vector through a head.

    The module performs two steps:
    1. Collapses the sequence dimension (only ``"mean"`` pooling is implemented)
    2. Feeds the pooled tensor to the wrapped prediction head
    """

    def __init__(self, prediction_head: nn.Module, pooling_method: str = "mean"):
        super().__init__()
        # The method name is stored as-is; it is only checked when forward() runs.
        self.pooling_method = pooling_method
        self.prediction_head = prediction_head

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Pool ``embeddings`` over dim 1 and apply the prediction head.

        Args:
            embeddings: Tensor of shape [batch_size, seq_len, hidden_dim]

        Returns:
            Whatever the prediction head returns for the pooled
            [batch_size, hidden_dim] tensor.

        Raises:
            ValueError: If ``pooling_method`` is anything other than "mean".
        """
        # Guard clause: reject every pooling method we do not implement.
        if self.pooling_method != "mean":
            raise ValueError(f"Unsupported pooling method: {self.pooling_method}")

        # Average across the sequence axis -> [batch_size, hidden_dim]
        sequence_mean = torch.mean(embeddings, dim=1)
        return self.prediction_head(sequence_mean)


class FeatureTensorDataset(Dataset):
    """
    TensorDataset that returns dict batches compatible with model forward methods.

    This dataset takes pre-computed feature tensors and labels, and returns
    batches in the format expected by HAIPRModule.forward().
    """

    def __init__(self, features_dict: Dict[str, torch.Tensor], labels: torch.Tensor | None = None):
        """
        Initialize the dataset with pre-computed features.

        Args:
            features_dict: Dictionary of feature tensors
            labels: Labels tensor
        """
        self.features = features_dict
        self.labels = labels

        # Validate all tensors have same length
        if labels is not None:
            lengths = set(len(v) for v in features_dict.values())
            assert (
                len(lengths) == 1
            ), f"All feature tensors must have same length, got {lengths}"
            assert (
                len(labels) == list(lengths)[0]
            ), f"Labels length {len(labels)} must match features length {list(lengths)[0]}"

    def __len__(self):
        return len(next(iter(self.features.values())))

    def __getitem__(self, idx):
        """
        Get a single item from the dataset. Return format is compatible with HAIPRModule.forward().

        Returns:
            Dict with "inputs" (dict of features) and "labels" (tensor)
        """
        batch_features = {k: v[idx] for k, v in self.features.items()}
        if self.labels is not None:
            return {"inputs": batch_features, "labels": self.labels[idx]}
        else:
            return {"inputs": batch_features}
