"""Defines the configuration dataclass for the EnergyModel."""

from dataclasses import dataclass
from typing import Literal


@dataclass
class EBMConfig:
    """Configuration class for the EnergyModel architecture.

    This dataclass holds all hyperparameters that define the structure and
    behavior of the EnergyModel, from the underlying text encoder to the
    final energy head MLP. Values are validated on construction (see
    ``__post_init__``) so that a malformed configuration fails immediately
    rather than when the model is built.

    Attributes:
        text_encoder_model_type: The type of sentence embedding model.
            "SentenceBERT" uses a standard sentence-transformer model.
            "BERT_CLS" uses the [CLS] token from a standard BERT model.
        placeholder_type: The type of placeholder embedding for padding sentences.
            "learnable" creates a trainable tensor. "static" uses zeros.
        placeholder_init_type: The initialization method for learnable placeholders.
            "random" uses a standard normal distribution. "zero" uses zeros.
        n_sentences: The fixed number of sentences to normalize all texts to by
            padding with placeholders or truncating. Must be positive.
        d_model: The embedding dimension of the text encoder and attention layers.
            Must be positive and divisible by ``attention_n_heads``.
        self_attention_n_layers: The number of self-attention encoder layers.
            Must be non-negative.
        cross_attention_n_layers: The number of cross-attention encoder layers.
            Must be non-negative.
        attention_n_heads: The number of heads in the multi-head attention mechanisms.
            Must be positive.
        dropout_rate: The dropout rate used for regularization in attention and MLP
            layers. Must lie in the closed interval [0, 1].
        energy_head_mlp_layers: The number of layers in the final MLP energy head.
            Must be non-negative.
        energy_head_hidden_factor: A multiplier for the hidden dimension size in the
            energy head MLP (e.g., a factor of 2 means hidden_dim = 2 * d_model).
            Must be positive.
        energy_head_pooling_type: The pooling strategy to aggregate sentence embeddings
            before the energy head. "attention" uses attention pooling, "flatten"
            concatenates them.
        activation_fn: The activation function for the MLP layers ("ReLU" or "GELU").

    Raises:
        ValueError: If any field holds a value outside its documented domain.
    """

    # Sentence Encoder model
    text_encoder_model_type: Literal["SentenceBERT", "BERT_CLS"] = "SentenceBERT"

    # Placeholder settings
    placeholder_type: Literal["learnable", "static"] = "learnable"
    placeholder_init_type: Literal["random", "zero"] = "random"

    # General settings
    n_sentences: int = 16

    # Attention settings
    d_model: int = 768
    self_attention_n_layers: int = 2
    cross_attention_n_layers: int = 2
    attention_n_heads: int = 8
    dropout_rate: float = 0.1

    # Energy head settings
    energy_head_mlp_layers: int = 2
    energy_head_hidden_factor: int = 2
    energy_head_pooling_type: Literal["flatten", "attention"] = "attention"
    activation_fn: Literal["ReLU", "GELU"] = "GELU"

    def __post_init__(self) -> None:
        """Validate all hyperparameters right after dataclass initialization.

        Raises:
            ValueError: If a ``Literal``-typed field holds a value that is not
                one of its declared options, if a size/count field is out of
                range, if ``dropout_rate`` is outside [0, 1], or if ``d_model``
                is not divisible by ``attention_n_heads``.
        """
        # Imported locally so the module's top-level import block is unchanged.
        from typing import get_args, get_origin, get_type_hints

        # Literal annotations are not enforced at runtime; derive the allowed
        # options from the annotations themselves so the check cannot drift
        # from the field declarations.
        for name, hint in get_type_hints(type(self)).items():
            if get_origin(hint) is not Literal:
                continue
            allowed = get_args(hint)
            value = getattr(self, name)
            if value not in allowed:
                raise ValueError(
                    f"{name} must be one of {allowed!r}, got {value!r}"
                )

        # Sizes and multipliers must be strictly positive.
        for name in (
            "n_sentences",
            "d_model",
            "attention_n_heads",
            "energy_head_hidden_factor",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")

        # Layer counts may be zero (component disabled) but never negative.
        for name in (
            "self_attention_n_layers",
            "cross_attention_n_layers",
            "energy_head_mlp_layers",
        ):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value!r}")

        if not 0.0 <= self.dropout_rate <= 1.0:
            raise ValueError(
                f"dropout_rate must be in [0, 1], got {self.dropout_rate!r}"
            )

        # Standard multi-head attention splits d_model evenly across heads.
        # NOTE(review): assumes the model uses such a split — confirm against
        # the attention implementation that consumes this config.
        if self.d_model % self.attention_n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model!r}) must be divisible by "
                f"attention_n_heads ({self.attention_n_heads!r})"
            )
