from dataclasses import dataclass
from typing import Literal, Optional

from ruamel.yaml import YAML, yaml_object

from models import ModelConfig
from ._loss_functions import IMLEKargerConfig


@yaml_object(YAML())
@dataclass
class TrainingConfig:
    """
    Settings describing a single training run. The class is registered with ruamel.yaml through `yaml_object`.

    Fields:

    - `dataset_name`: Name of the dataset used for training.
                      Every dataset is expected in `data/`, in a file called `dataset_name + ".pt"`.
    - `model_config`: Describes the model that is trained.
    - `loss_function`: Selects the training loss.
                       One of `"supervised_binary_cross_entropy"`, `"supervised_imle"`, `"self_supervised_imle"`,
                       `"direct_gradient"`, `"tsp_direct_gradient"`, or `"reinforce"`.
                       With `"supervised_binary_cross_entropy"` or `"supervised_imle"`, every graph in the dataset
                       has to provide a ground truth minimum cut as `graph.y`.
    - `imle_karger_config`: Settings for the I-MLE gradient estimation.
                            Required (must not be `None`) when `loss_function` is `"supervised_imle"` or
                            `"self_supervised_imle"`; ignored for every other loss function.
                            This requirement is not validated by this class.
    - `batch_size`: How many graphs make up one mini-batch.
    - `initial_lr`: Learning rate at the start of training.
    - `lr_scheduler_patience`: How many epochs without improvement the learning rate scheduler tolerates before it
                               lowers the learning rate.
    - `num_epochs`: Upper bound on the number of training epochs, unless training is stopped earlier by other means.
    - `device`: Training runs on `torch.device(device)`.
    - `checkpoint_interval`: A training checkpoint (including the model parameters) is written once every
                             `checkpoint_interval` evaluations. Defaults to `1`.
                             Because the interval is counted in evaluations, changing `num_evaluations_per_epoch`
                             also changes how frequently checkpoints are written.
                             Checkpoints are placed in `logs/`.
    - `num_evaluations_per_epoch`: Number of model evaluations performed during each epoch. Defaults to `1`.
    """

    # Data and model selection.
    dataset_name: str
    model_config: ModelConfig

    # Loss selection and the I-MLE settings that some of the losses depend on.
    loss_function: Literal[
        "supervised_binary_cross_entropy",
        "supervised_imle",
        "self_supervised_imle",
        "direct_gradient",
        "tsp_direct_gradient",
        "reinforce",
    ]
    imle_karger_config: Optional[IMLEKargerConfig]

    # Optimisation hyperparameters.
    batch_size: int
    initial_lr: float
    lr_scheduler_patience: int
    num_epochs: int
    device: str

    # Evaluation and checkpoint cadence; these are the only fields with defaults.
    checkpoint_interval: int = 1
    num_evaluations_per_epoch: int = 1
