import os
from pathlib import Path
from typing import Optional

from ruamel.yaml import YAML
import torch
from torch_geometric.data import Data
from tqdm import tqdm

from constants import DATA_DIR, TQDM_OPTIONS
from ._dataset import Dataset
from ._dataset_config import DatasetConfig, DatasetPartConfig
from ._save_in_memory_dataset import save_in_memory_dataset
from ._util import graph_file_name


def generate_dataset(dataset_config: DatasetConfig):
    """
    Generates the training and validation graphs described by `dataset_config` and writes the config as YAML.

    If `dataset_config.in_memory` is set, the graphs are wrapped in a `Dataset` that is handed to
    `save_in_memory_dataset`, and the config is written to `DATA_DIR / (dataset_config.name + "--config.yml")`.
    Otherwise the graphs are written to individual files while they are generated, and the config is written to
    `DATA_DIR / dataset_config.name / "config.yml"`.
    """
    print("Generating training set...")
    training_graphs = _generate_dataset_parts(dataset_config, validation_set=False)
    print("Generating validation set...")
    validation_graphs = _generate_dataset_parts(dataset_config, validation_set=True)

    if not dataset_config.in_memory:
        # every graph was already written to its own file during generation, only the config is missing
        config_path = DATA_DIR / dataset_config.name / "config.yml"
    else:
        in_memory_dataset = Dataset(
            config=dataset_config,
            train_graphs=training_graphs,
            val_graphs=validation_graphs,
        )
        save_in_memory_dataset(in_memory_dataset)
        config_path = DATA_DIR / (dataset_config.name + "--config.yml")

    with open(config_path, "w") as config_file:
        yaml_writer = YAML()
        yaml_writer.dump(dataset_config, config_file)


def _generate_dataset_parts(dataset_config: DatasetConfig, validation_set: bool) -> list[Data]:
    """
    Generates all parts of either the training set or the validation set of `dataset_config`.

    If `validation_set` is `True`, `dataset_config.num_val_graphs` graphs are generated and ground truth
    calculation is always requested.
    Otherwise `dataset_config.num_train_graphs` graphs are generated and ground truth calculation is only requested
    if `dataset_config.ground_truth_for_train_set` is set.

    The graphs are distributed over the parts proportionally to each part's `weight_num_graphs`.
    Graphs that are left over after rounding the shares down are assigned to the parts with the largest fractional
    remainders (earlier parts win ties), so the part sizes always sum up to the requested total.

    If `dataset_config.in_memory` is `False`, the graphs are stored in individual files inside
    `DATA_DIR / dataset_config.name / "train"` (or `"val"`) and an empty list is returned.
    Otherwise a list containing all generated graphs is returned.

    Raises a `ValueError` if there are part configs but their weights do not sum up to a positive number.
    """
    if validation_set:
        num_graphs_total = dataset_config.num_val_graphs
        calculate_ground_truth = True
    else:
        num_graphs_total = dataset_config.num_train_graphs
        calculate_ground_truth = dataset_config.ground_truth_for_train_set

    if not dataset_config.in_memory:
        output_dir = DATA_DIR / dataset_config.name / ("val" if validation_set else "train")
        os.makedirs(output_dir, exist_ok=True)
    else:
        # don't save graphs in separate files
        output_dir = None

    part_configs = dataset_config.part_configs
    part_weights = [part_config.weight_num_graphs for part_config in part_configs]
    part_weight_sum = sum(part_weights)
    if part_weights and part_weight_sum <= 0:
        raise ValueError(f"The weights of the dataset parts must sum up to a positive number, got {part_weight_sum}")

    # Largest remainder method: round every share down first ...
    exact_shares = [num_graphs_total * part_weight / part_weight_sum for part_weight in part_weights]
    num_graphs_per_part = [int(share) for share in exact_shares]
    # ... then hand out the graphs lost by rounding, so that no graph of the requested total goes missing.
    num_missing = max(num_graphs_total - sum(num_graphs_per_part), 0)
    parts_by_remainder = sorted(
        range(len(exact_shares)),
        key=lambda part: exact_shares[part] - num_graphs_per_part[part],
        reverse=True,
    )
    for part in parts_by_remainder[:num_missing]:
        num_graphs_per_part[part] += 1

    dataset = []
    # number of graphs generated by previous parts; keeps file names unique across parts
    offset = 0
    for part, (part_config, num_graphs_part) in enumerate(zip(part_configs, num_graphs_per_part)):
        print(f"Part {part + 1}/{len(part_configs)}")
        dataset += _generate_dataset_part(part_config, num_graphs_part, calculate_ground_truth, output_dir, offset)
        offset += num_graphs_part

    return dataset


def _generate_dataset_part(
    part_config: DatasetPartConfig,
    num_graphs: int,
    calculate_ground_truth: bool,
    output_dir: Optional[Path],
    offset: int,
) -> list[Data]:
    """
    Generates `num_graphs` graphs according to `part_config`.
    If `calculate_ground_truth` is `True`, ground truth solutions are also generated according to `part_config`.

    If `output_dir` is given, the graphs are stored in individual files inside that directory and an empty list is
    returned.
    The number in each graph's file name is increased by `offset`.
    If the dataset consists of multiple parts, this should be used to avoid file name collisions between parts.
    Graphs whose file already exists are skipped, which allows resuming an interrupted run.
    Each graph is first saved to a temporary `.tmp` file next to its final location and then renamed, so an
    interrupted run never leaves a truncated graph file behind that would be skipped on the next run.

    If `output_dir` is not given, a list containing the generated graphs is returned and `offset` only influences
    the (unused) file names.
    """
    graphs = []

    for i in tqdm(range(num_graphs), **TQDM_OPTIONS):
        file_name = graph_file_name(i + offset)
        graph_path = None if output_dir is None else output_dir / file_name

        if graph_path is not None and graph_path.exists():
            # tqdm.write prints the message without garbling the active progress bar
            tqdm.write(f"Skipping {file_name}, because the file already exists")
            continue

        graph = _generate_graph(part_config, calculate_ground_truth)

        if graph_path is None:
            graphs.append(graph)
            continue

        # Write to a temporary sibling and rename it afterwards: the final file name only ever
        # refers to a completely written graph, which the "skip existing files" check relies on.
        tmp_path = graph_path.with_name(graph_path.name + ".tmp")
        torch.save(graph, tmp_path)
        tmp_path.replace(graph_path)

    return graphs


def _generate_graph(part_config: DatasetPartConfig, calculate_ground_truth: bool) -> Data:
    """
    Creates one graph using the graph generator config of `part_config`.

    The ground truth solver config of `part_config` is used to fill in `graph.y` only if `calculate_ground_truth`
    is `True`, the generated graph has no `y` yet and a ground truth solver config is present.
    """
    graph = part_config.graph_generator_config.generate_graph()

    # nothing to solve if no ground truth was requested or the generator already provided one
    if not calculate_ground_truth or graph.y is not None:
        return graph

    solver_config = part_config.ground_truth_solver_config
    if solver_config is not None:
        graph.y = solver_config.solve(graph)

    return graph


if __name__ == "__main__":
    # Script entry point: generates a TSP dataset with 20 nodes per graph, stored as individual files.
    from ._graph_generators import TSPGraphConfig
    from ._ground_truth_solvers import TSPSolverConfig

    # the dataset consists of a single part, so the part receives all graphs
    TSP_PART_CONFIG = DatasetPartConfig(
        weight_num_graphs=1,
        graph_generator_config=TSPGraphConfig(num_nodes=20),
        ground_truth_solver_config=TSPSolverConfig(),
    )

    DATASET_CONFIG = DatasetConfig(
        name="tsp--n-20",
        num_train_graphs=10_000,
        num_val_graphs=1_000,
        part_configs=[TSP_PART_CONFIG],
        ground_truth_for_train_set=True,
        in_memory=False,
    )
    generate_dataset(DATASET_CONFIG)
