from dataclasses import (
    dataclass,
    field,
)
from transformers import PretrainedConfig


# Uncomment the calibration datasets you want to use.
def get_default_calibration_datasets():
    """Build the default calibration dataset table.

    Returns:
        A newly created dict keyed by dataset name, in the order the rows
        are listed below. Each value is a dict with the keys
        ``'data_number'`` and ``'batch_size'``; ``'batch_size'`` is 1 for
        every dataset.
    """
    # (dataset name, data_number) rows. Comment a row in or out to change
    # which calibration datasets are part of the default.
    data_numbers = (
        ('alpaca', 82),
        # ('c4', 35),
        ('openbookqa', 385),
        ('piqa', 489),
        ('wikitext2', 256),
    )
    return {
        dataset: {'data_number': count, 'batch_size': 1}
        for dataset, count in data_numbers
    }


@dataclass
class MoLoSConfig(PretrainedConfig):
    """Configuration for the MoLoS model.

    The fields fall into three groups: the calibration datasets (plus the
    sequence length used with them), the mixture-of-experts router options,
    and the expert layout. Every field declared here has a default value.

    The constraints between fields that are described in the comments below
    are not checked anywhere in this class definition; callers are
    responsible for keeping the values consistent.

    NOTE(review): ``@dataclass`` generates an ``__init__`` for these fields
    and nothing visible here calls ``PretrainedConfig.__init__`` -- confirm
    that the base-class initialisation is not required.
    """

    # Calibration dataset.
    ## Maps a dataset name to its settings dict (keys 'data_number' and
    ## 'batch_size' in the default returned by get_default_calibration_datasets).
    ## The number of calibration datasets must be the same as the number of experts (ex_num).
    ## A default_factory is used so each instance gets its own dict.
    calibration_datasets: dict[str, dict[str, int]] = \
        field(default_factory=get_default_calibration_datasets)

    ## The sequence length used for whitening with the calibration dataset.
    sequence_length: int = 2048

    # Mixture of Experts.
    ## Router settings. Presumably: noise applied to router inputs, whether the
    ## router logits are returned, and the weights of the auxiliary and z
    ## losses -- TODO confirm against the model code that reads these fields.
    jitter_noise: float = 0.025
    output_router_logits: bool = True
    router_aux_loss_coef: float = 0.5
    router_z_loss_coef: float = 0

    ## The type of the routing mechanism.
    ## This parameter can be `sequence` or `token`.
    ## - If select_type is `sequence`, selected_ex_num must be 1 and output_router_logits must be False.
    ## - If select_type is `token`, selected_ex_num can be any number from 1 to ex_num.
    select_type: str = 'token'

    ## `ex` is short for expert.
    ## ex_num is the number of experts; selected_ex_num is how many of them are
    ## selected by the router (see the select_type constraints above).
    ## NOTE(review): ex_params_ratio looks like a per-expert parameter ratio --
    ## confirm its exact meaning against the code that consumes it.
    ex_num: int = 4
    selected_ex_num: int = 2
    ex_params_ratio: float = 0.25

    ## Only for debugging.
    ## -1 is the default "unset" value. If this parameter is set to a value
    ## different from -1, you should set output_router_logits to False.
    chosen_ex_idx: int = -1
