import os

from ruamel.yaml import YAML
import torch

from constants import DATA_DIR
from util import get_timestamp, yaml_object_to_string
from ._dataset import Dataset


def save_in_memory_dataset(dataset: Dataset):
    """
    Saves the dataset in the data directory, in a single file named `dataset_name + ".pt"`.
    Also stores the dataset config as `dataset_name + "--config.yml"` in the same directory.

    If `dataset_name + ".pt"` already exists, nothing is overwritten: a warning is printed and the
    dataset is saved as `dataset_name + "--v" + <timestamp>` instead. The config file then uses that
    same versioned name, so the two files stay paired.

    Do not use this if `dataset.train_graphs` or `dataset.val_graphs` is of type `LargeDataset`!

    :param dataset: the dataset to save; `dataset.config.name` determines the file names
    :return: the path of the written `.pt` file (which may be the versioned name, see above)
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    dataset_name = dataset.config.name
    dataset_path = DATA_DIR / (dataset_name + ".pt")
    if dataset_path.exists():
        new_dataset_name = dataset_name + "--v" + get_timestamp()
        print(f'WARNING: Dataset "{dataset_name}.pt" already exists. Saving as "{new_dataset_name}.pt" instead')
        dataset_name = new_dataset_name
        dataset_path = DATA_DIR / (dataset_name + ".pt")

    # Serialize the config once; it is both embedded in the .pt file and written to the sidecar .yml.
    config_string = yaml_object_to_string(dataset.config)

    # save dataset (as a dict, because saving the object directly can lead to issues if the code changes between
    # saving and loading)
    dataset_dict = {
        "config": config_string,
        "train_graphs": dataset.train_graphs,
        "val_graphs": dataset.val_graphs,
    }
    torch.save(dataset_dict, dataset_path)

    # Human-readable copy of the config next to the .pt file, named after the final (possibly versioned) name.
    config_path = DATA_DIR / (dataset_name + "--config.yml")
    with open(config_path, "w", encoding="utf-8") as config_file:
        config_file.write(config_string)

    return dataset_path
