from dataclasses import dataclass
from typing import Literal

from ruamel.yaml import YAML, yaml_object


@yaml_object(YAML())
@dataclass()
class IMLEKargerConfig:
    """
    Configuration for I-MLE using Karger-Stein as the solver; registered with ruamel.yaml so it can be
    loaded from / dumped to YAML as a tagged object.

    Fields:

    - `target_distribution`: Which target distribution to use; must be one of `"general_purpose"` or
                             `"karger"`. (Presumably selects how the I-MLE target distribution for the
                             backward pass is constructed — TODO confirm against the code consuming this config.)
    - `num_karger_runs_per_noise_sample`: The number of times Karger-Stein is run on each set of perturbed steering
                                          weights.
    - `num_noise_samples`: The number of perturbed sets of steering weights to generate.
    - `input_noise_temperature`: Scaling factor for the noise used during the forward pass.
    - `target_noise_temperature`: Scaling factor for the noise used during the backward pass.
    - `sog_noise_k`: The parameter kappa to use for the Sum-of-Gamma noise distribution.
    - `sog_noise_iterations`: The number of iterations to use for the Sum-of-Gamma noise distribution.

    Note that Karger-Stein is run `num_noise_samples * num_karger_runs_per_noise_sample` times during the forward pass,
    and this many times again during the backward pass.

    No validation is performed on construction: the `Literal` annotation on `target_distribution` is only a
    static-typing hint and is not enforced at runtime.
    """

    target_distribution: Literal["general_purpose", "karger"]
    num_karger_runs_per_noise_sample: int
    num_noise_samples: int
    input_noise_temperature: float
    target_noise_temperature: float
    sog_noise_k: float
    sog_noise_iterations: int
