# postponed evaluation of annotations
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path

from ruamel.yaml import YAML
import torch
from torch_geometric.data import Data

from constants import DATA_DIR
from ._dataset_config import DatasetConfig
from ._large_dataset import LargeDataset


@dataclass(frozen=True)
class Dataset:
    """
    A dataset of graphs, with train and val splits.
    Ground truth solutions may be included for all graphs or for none of the graphs, in the field `y` of the `Data`
    objects.

    There are two kinds of datasets:
    datasets in which all graphs are stored in a single file (in-memory datasets), and datasets in which each graph is
    stored in a separate file.
    For the first kind of dataset, `train_graphs` and `val_graphs` are lists of `Data` objects,
    meaning that all graphs are loaded into memory at once.
    For the second kind, only the graphs that are currently worked on are loaded.
    This makes the second kind more suitable for larger datasets that don't fit into memory.
    For this second kind, `train_graphs` and `val_graphs` are `LargeDataset` objects.
    """

    # Configuration the dataset was loaded with.
    config: DatasetConfig
    # Graphs of the train split (in-memory list or lazily loading `LargeDataset`).
    train_graphs: list[Data] | LargeDataset
    # Graphs of the val split (in-memory list or lazily loading `LargeDataset`).
    val_graphs: list[Data] | LargeDataset

    # `load` takes no `self`/`cls`, so it must be a static method; without the decorator it only works when
    # called on the class and breaks when called on an instance.
    @staticmethod
    def load(dataset_name: str) -> Dataset:
        """
        Loads the dataset with the given name from the data directory.

        The dataset is looked up in `DATA_DIR` either as the single file `<dataset_name>.pt` or as the directory
        `<dataset_name>`.

        Args:
            dataset_name: Name of the dataset, without file extension.

        Returns:
            The loaded dataset.

        Raises:
            Exception: If both the single file and the directory exist, since it is ambiguous which one is meant.
            ValueError: If neither the single file nor the directory exists.
        """
        dataset_path_file = DATA_DIR / (dataset_name + ".pt")
        dataset_path_dir = DATA_DIR / dataset_name

        if dataset_path_file.is_file() and dataset_path_dir.is_dir():
            raise Exception(
                f'Dataset "{dataset_name}" exists both as a single file and as a directory. '
                'Delete or rename one of the two to resolve the ambiguity'
            )
        elif dataset_path_file.is_file():
            return _load_dataset_from_single_file(dataset_path_file)
        elif dataset_path_dir.is_dir():
            return _load_dataset_from_directory(dataset_path_dir)
        else:
            raise ValueError(f'Dataset "{dataset_name}" does not exist')


def _load_dataset_from_single_file(dataset_path: Path) -> Dataset:
    """
    Builds a `Dataset` from a single file written with `torch.save`.

    The file is expected to hold a mapping with the entries "config", "train_graphs" and "val_graphs".
    The "config" entry is parsed with ruamel's YAML loader, the two graph entries are passed on unchanged.
    """
    # NOTE(review): `weights_only=False` lets torch unpickle arbitrary objects, so only trusted files should be
    # loaded here.
    stored = torch.load(dataset_path, weights_only=False)

    parsed_config = YAML().load(stored["config"])
    train_split = stored["train_graphs"]
    val_split = stored["val_graphs"]

    return Dataset(config=parsed_config, train_graphs=train_split, val_graphs=val_split)


def _load_dataset_from_directory(dataset_path: Path) -> Dataset:
    """
    Builds a `Dataset` from a dataset directory.

    The configuration is read from `config.yml` inside the directory.
    Both splits are wrapped in `LargeDataset` objects created from that configuration.
    """
    config_file = dataset_path / "config.yml"
    config = YAML().load(config_file)

    train_split = LargeDataset(config, split="train")
    val_split = LargeDataset(config, split="val")
    return Dataset(config=config, train_graphs=train_split, val_graphs=val_split)
