# postponed evaluation of annotations (use a class name of a class defined later as type hint)
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional

from ruamel.yaml import YAML, yaml_object

from ._graph_generators import GraphGeneratorConfig
from ._ground_truth_solvers import GroundTruthSolverConfig


@yaml_object(YAML())
@dataclass
class DatasetConfig:
    """
    Top-level description of how a dataset is assembled.

    Fields:

    - `name`: Identifier of the dataset.
    - `num_train_graphs`: Requested total number of training graphs. The total is split across the parts according
                          to their weights, so rounding may make the number actually produced slightly smaller.
    - `num_val_graphs`: Requested total number of validation graphs. As with `num_train_graphs`, rounding may make
                        the number actually produced slightly smaller.
    - `part_configs`: One entry per part of the dataset; each entry states which kind of graphs the part contains
                      and its relative share of the total number of graphs.
    - `ground_truth_for_train_set`: `True` computes ground truth solutions for both the training and the validation
                                    set; `False` computes them for the validation set only. Choose `False` when
                                    ground truth is expensive to compute and not needed for training.
                                    Has no effect when the graph generator already produces the ground truth
                                    together with the graph.
    - `in_memory`: `True` keeps the whole dataset in a single file on disk that is loaded into memory in one go;
                   `False` writes every graph to its own file and loads only the graphs currently in use.
    """

    # Required fields (no defaults) -- their order defines the positional `__init__` signature.
    name: str
    num_train_graphs: int
    num_val_graphs: int
    part_configs: list[DatasetPartConfig]

    # Optional switches.
    ground_truth_for_train_set: bool = True
    in_memory: bool = True


@yaml_object(YAML())
@dataclass()
class DatasetPartConfig:
    """
    Fields:

    - `weight_num_graphs`: The relative number of graphs in this part of the dataset.
                           The `weight_num_graphs`s for all parts of the dataset are normalised to sum up to one, then
                           multiplied with the total number of graphs in the dataset to calculate the number of graphs
                           in each part.
                           Must be greater than 0; a `ValueError` is raised on construction otherwise (NaN is rejected
                           as well).
    - `graph_generator_config`: Specifies how to generate the graphs in this part of the dataset.
    - `ground_truth_solver_config`: If this is not `None`, a ground truth solution will be calculated as specified and
                                    added to the graph as `graph.y`.
                                    Ignored if the graph generator already generates the ground truth solution with the
                                    graph.
    """

    weight_num_graphs: float
    graph_generator_config: GraphGeneratorConfig
    ground_truth_solver_config: Optional[GroundTruthSolverConfig] = None

    def __post_init__(self) -> None:
        """Validate the documented invariants right after dataclass construction.

        Raises:
            ValueError: If `weight_num_graphs` is not greater than 0.
        """
        # Fail fast here rather than later, when the weights are normalised and turned into per-part graph counts.
        # `not x > 0` (instead of `x <= 0`) also rejects NaN, for which every comparison is False.
        # NOTE(review): ruamel.yaml's default object constructor presumably creates instances without calling
        # `__init__`, so this check may not run for configs loaded from YAML -- confirm and validate there too.
        if not self.weight_num_graphs > 0:
            raise ValueError(
                f"weight_num_graphs must be greater than 0, got {self.weight_num_graphs!r}"
            )
