import torch
import torch.nn as nn


class ForwardScorer(nn.Module):
    """
    Parameter-free scorer that returns one feature column as the score.

    Useful for testing, or when a pre-computed score (like self-frequency)
    should be used directly without learning. The module registers no
    parameters; it only stores the column index to read.
    """
    def __init__(self, feature_index: int = 0):
        """
        Args:
            feature_index: which feature column to use as the score
        """
        super().__init__()
        self.feature_index = feature_index

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Extract a single feature column as scores.

        Args:
            features: [n, m] feature matrix. A 3-D input with a trailing
                singleton axis ([n, m, 1]) is accepted as well.

        Returns:
            scores: [n] scores (the specified feature column). The leading
            (row) axis is always kept, so a single-row input yields shape
            [1], never a 0-d tensor.

        Raises:
            IndexError: if ``features`` has fewer than two dimensions or
                ``feature_index`` is out of range for axis 1.
        """
        # squeeze(-1) is a no-op unless the last axis has size 1, so this only
        # strips a trailing singleton from 3-D inputs.
        if features.ndim == 3:
            features = features.squeeze(-1)

        # Extract the specified feature column (axis 1).
        scores = features[:, self.feature_index]

        # Drop leftover singleton axes, but never axis 0: a bare ``squeeze()``
        # would also remove the row axis when n == 1 and change the output rank
        # depending on batch size. Walking from the last axis down keeps the
        # remaining axis indices valid after each squeeze.
        for dim in range(scores.ndim - 1, 0, -1):
            scores = scores.squeeze(dim)

        return scores


class LogisticClaimScorer(nn.Module):
    """
    Single linear layer mapping a feature vector to one score per row.

    The forward pass returns the raw linear output; no sigmoid is applied
    inside this module.

    Args:
        input_dim: Number of input features
    """
    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Score each feature vector.

        Args:
            features: tensor whose last axis holds the input features; a 3-D
                tensor with a trailing singleton axis has that axis removed
                first.

        Returns:
            The linear layer's output with its final size-1 axis removed.
        """
        # squeeze(-1) leaves the tensor untouched unless the last axis is 1.
        inputs = features.squeeze(-1) if features.ndim == 3 else features
        logits = self.linear(inputs)
        return logits.squeeze(-1)


class WarmStartLogisticClaimScorer(nn.Module):
    """
    Linear claim scorer whose initial weights reproduce one input feature.

    Rather than keeping the random initialization of ``nn.Linear``, the
    constructor zeroes the weight row and the bias, then sets the single
    weight at ``align_dim`` to ``init_scale``. Immediately after construction
    the output is therefore ``init_scale * features[..., align_dim]``, i.e.
    exactly the baseline feature when ``init_scale`` is 1.0.

    The intent is to start training from a "safe" region of parameter space
    where the baseline (e.g., frequency-score) does not collapse, which helps
    prevent collapse at strict alpha levels (<= 0.02).

    Args:
        input_dim: Number of input features
        align_dim: Feature index to align to at initialization (default: 0 for frequency-score)
        init_scale: Scale factor for initial weight (default: 1.0)
    """
    def __init__(self, input_dim: int, align_dim: int = 0, init_scale: float = 1.0):
        super().__init__()
        self.align_dim = align_dim
        self.linear = nn.Linear(input_dim, 1)

        # Overwrite the default init in place; no_grad keeps these writes out
        # of autograd while leaving the parameters trainable afterwards.
        with torch.no_grad():
            weight, bias = self.linear.weight, self.linear.bias
            bias.zero_()
            weight.zero_()
            # Only the aligned feature contributes to the initial output.
            weight[0, align_dim] = init_scale

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Score each feature vector.

        Args:
            features: tensor whose last axis holds the input features; a 3-D
                tensor with a trailing singleton axis has that axis removed
                first.

        Returns:
            The linear layer's output with its final size-1 axis removed.
        """
        inputs = features.squeeze(-1) if features.ndim == 3 else features
        logits = self.linear(inputs)
        return logits.squeeze(-1)


class MLPClaimScorer(nn.Module):
    """
    MLP claim scorer: two hidden blocks followed by a linear output head.

    Each hidden block is Linear -> LayerNorm -> ReLU -> Dropout; the head is
    a Linear layer producing one value per row.

    Args:
        input_dim: Number of input features
        hidden_dim: Width of both hidden blocks (default: 64)
        dropout: Dropout probability used in both hidden blocks (default: 0.2)
    """
    def __init__(self, input_dim: int, hidden_dim: int = 64, dropout: float = 0.2):
        super().__init__()
        stack = []
        in_width = input_dim
        # Two identical hidden blocks; only the first one's input width differs.
        for _ in range(2):
            stack.extend([
                nn.Linear(in_width, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_width = hidden_dim
        # Output head: one score per row.
        stack.append(nn.Linear(hidden_dim, 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Score each feature vector.

        Args:
            features: tensor whose last axis holds the input features; a 3-D
                tensor with a trailing singleton axis has that axis removed
                first.

        Returns:
            The network output with its final size-1 axis removed.
        """
        inputs = features.squeeze(-1) if features.ndim == 3 else features
        outputs = self.layers(inputs)
        return outputs.squeeze(-1)