from dataclasses import dataclass
from typing import Literal


@dataclass
class InterpreterConfig:
    """Configuration class for InterpreterModel.

    This configuration defines all hyperparameters for the interpreter model
    that explains which input sentences are most important for generating
    specific target output sentences.

    Choice-valued fields (those annotated with ``Literal``) and basic numeric
    ranges are validated in ``__post_init__``. An invalid value raises
    ``ValueError`` when the config is built, rather than failing later and
    less clearly.

    Attributes:
        # INTERP-1: Input and Target Embedding Configuration
        target_embedding (str): How to embed the target sentence.
            - "isolated": Embed target sentence alone
            - "contextualized": Embed target within full output context
        encoder_sharing_strategy (str): How encoders are shared with the EBM.
            - "shared": Use same encoder instances as EBM (weight sharing)
            - "cloned": Copy encoder weights but train independently
        freeze_encoder_in_interp (bool): Whether to freeze encoder weights
            during interpreter training.

        # INTERP-2: Cross-Attention Configuration
        cross_attention_layers (int): Number of cross-attention layers (>= 0).
        attention_heads (int): Number of attention heads in cross-attention
            (>= 1).
        dropout_rate (float): Dropout rate for cross-attention layers, in
            [0, 1].

        # INTERP-3: Pooling and Scoring Configuration
        mlp_hidden_dim (int): Hidden dimension for the scoring MLP. Validated
            (>= 1) only when ``mlp_layers > 1``, since a single linear layer
            has no hidden layer.
        mlp_layers (int): Number of layers in the scoring MLP (>= 1).
            - 1: Single linear layer (sentence_embedding -> score)
            - >1: Multi-layer MLP with hidden layers

        # INTERP-4: Masking Configuration
        masking_type (str): Main type of masking to apply.
            - "hard": Binary masking (keep/remove sentences completely)
            - "soft": Continuous masking (scale sentence embeddings)
        hard_mask_method (str): Method for hard masking.
            - "top_k": Keep only top-k highest scoring sentences
            - "threshold": Keep sentences above threshold
        soft_mask_method (str): Method for soft masking.
            - "multiply": Scale embeddings by importance scores
            - "interpolate": Interpolate between original and zero embeddings
        top_k (int): Number of sentences to keep for top_k masking. Must be
            >= 1 when masking_type is "hard" and hard_mask_method is "top_k".
        threshold (float): Threshold for threshold masking.
        comparison_type (str): What to compare in energy evaluation.
            - "full": Compare (input, output) vs (masked_input, output)
            - "targeted": Compare (input, target) vs (masked_input, target)

        # Gumbel Softmax Configuration
        gumbel_temperature (float): Temperature for Gumbel-softmax sampling
            (> 0).
        gumbel_hard (bool): Gumbel-softmax "hard" flag. Presumably this
            selects straight-through one-hot samples; confirm against the
            model code.
        gumbel_k (int): Number of times to sample Gumbel noise for robust
            top-k selection. Must be >= 1 when softmax_type is "gumbel".

        # INTERP-5: Training Configuration
        loss_type (str): Type of loss function to use.
            - "contrastive": Uses InfoNCE loss to distinguish original vs masked inputs
            - "regression": Predicts energy increase based on importance scores
            - "regularized": InfoNCE loss with sparsity regularization
        regression_loss_type (str): Subtype for regression loss.
            - "mse": Mean Squared Error
            - "mae": Mean Absolute Error (L1 loss)
            - "huber": Huber loss (smooth L1)
        energy_margin (float): Margin for contrastive loss (backward compatibility).
        infonce_temperature (float): Temperature parameter for InfoNCE loss
            scaling (> 0).

        # General Configuration
        activation_fn (str): Activation function for MLPs ("ReLU" or "GELU").
        softmax_type (str): Softmax variant used by the interpreter; the call
            site is outside this file.
            - "gumbel": Gumbel softmax (presumably where the gumbel_* fields apply)
            - "normal": Plain softmax

    Raises:
        ValueError: On construction, if any field holds an invalid value.
    """

    # INTERP-1: Input and Target Embedding
    target_embedding: Literal["isolated", "contextualized"] = "contextualized"
    encoder_sharing_strategy: Literal["shared", "cloned"] = "cloned"
    freeze_encoder_in_interp: bool = False

    # INTERP-2: Cross-Attention
    cross_attention_layers: int = 1
    attention_heads: int = 8
    dropout_rate: float = 0.1

    # INTERP-3: Pooling and Scoring - Uses MLP with configurable layers
    mlp_hidden_dim: int = 256
    mlp_layers: int = 1  # 1 = single linear layer, >1 = multi-layer MLP

    # INTERP-4: Masking
    masking_type: Literal["hard", "soft"] = "hard"
    hard_mask_method: Literal["top_k", "threshold"] = "top_k"
    soft_mask_method: Literal["multiply", "interpolate"] = "multiply"
    top_k: int = 8
    threshold: float = 0.5
    comparison_type: Literal["full", "targeted"] = "targeted"

    # Gumbel Softmax parameters
    gumbel_temperature: float = 1.0
    gumbel_hard: bool = False
    gumbel_k: int = 4  # Number of times to sample Gumbel noise for robust top-k selection

    # INTERP-5: Training
    loss_type: Literal["contrastive", "regression", "regularized"] = "contrastive"
    regression_loss_type: Literal["mse", "mae", "huber"] = "mse"
    energy_margin: float = 1.0
    infonce_temperature: float = 0.1

    # General
    activation_fn: Literal["ReLU", "GELU"] = "GELU"
    softmax_type: Literal["gumbel", "normal"] = "gumbel"

    def __post_init__(self) -> None:
        """Validate field values, raising ``ValueError`` on the first problem.

        ``Literal`` annotations are not enforced by Python at runtime. Without
        this check, a typo such as ``masking_type="Hard"`` would be accepted
        silently and only fail (or misbehave) deep inside the model.
        """
        # Local imports keep this edit self-contained w.r.t. the module header.
        from dataclasses import fields
        from typing import get_args, get_origin

        # Every Literal-annotated field must hold one of its declared options.
        for f in fields(self):
            if get_origin(f.type) is Literal:
                allowed = get_args(f.type)
                value = getattr(self, f.name)
                if value not in allowed:
                    raise ValueError(
                        f"{f.name} must be one of {allowed!r}, got {value!r}"
                    )

        if self.cross_attention_layers < 0:
            raise ValueError(
                f"cross_attention_layers must be >= 0, got {self.cross_attention_layers}"
            )
        if self.attention_heads < 1:
            raise ValueError(
                f"attention_heads must be >= 1, got {self.attention_heads}"
            )
        # Written as `not (lo <= x <= hi)` so that NaN is rejected too.
        if not 0.0 <= self.dropout_rate <= 1.0:
            raise ValueError(
                f"dropout_rate must be in [0, 1], got {self.dropout_rate}"
            )
        if self.mlp_layers < 1:
            raise ValueError(f"mlp_layers must be >= 1, got {self.mlp_layers}")
        # The hidden dimension only matters once there is a hidden layer.
        if self.mlp_layers > 1 and self.mlp_hidden_dim < 1:
            raise ValueError(
                f"mlp_hidden_dim must be >= 1 when mlp_layers > 1, got {self.mlp_hidden_dim}"
            )
        # top_k is only checked when top-k hard masking is actually selected.
        if (
            self.masking_type == "hard"
            and self.hard_mask_method == "top_k"
            and self.top_k < 1
        ):
            raise ValueError(
                f"top_k must be >= 1 for top_k hard masking, got {self.top_k}"
            )
        if self.softmax_type == "gumbel" and self.gumbel_k < 1:
            raise ValueError(
                f"gumbel_k must be >= 1 when softmax_type is 'gumbel', got {self.gumbel_k}"
            )
        # Temperatures act as scale factors; `not x > 0` also rejects NaN.
        if not self.gumbel_temperature > 0:
            raise ValueError(
                f"gumbel_temperature must be > 0, got {self.gumbel_temperature}"
            )
        if not self.infonce_temperature > 0:
            raise ValueError(
                f"infonce_temperature must be > 0, got {self.infonce_temperature}"
            )
