import os
from pathlib import Path

from ruamel.yaml import YAML
import torch
from torch.nn import DataParallel, Module
from torch_geometric.loader import DataLoader

from constants import LOGS_DIR
from data_generation import Dataset
from util import get_timestamp
from ._loss_functions import BCEWithLogitsLossGraphTarget, IMLEKargerLoss, KargerDirectGradientLoss, ReinforceLoss, \
                             TSPDirectGradientLoss
from ._training_config import TrainingConfig


def data_loaders(training_config: TrainingConfig) -> list[tuple[DataLoader, DataLoader]]:
    """
    Loads a dataset from disk and splits it into `training_config.num_evaluations_per_epoch` chunks.

    Graphs are assigned to chunks round-robin (graph `i` goes to chunk `i % num_chunks`), separately for the
    training set and the validation set. For every chunk the function builds
    1. A shuffling DataLoader over that chunk of the training set
    2. A non-shuffling DataLoader over the same-index chunk of the validation set
    and returns the list of `(train_loader, val_loader)` pairs, one pair per chunk.

    Raises:
        ValueError: if `num_evaluations_per_epoch` is less than 1, or if the training set contains fewer graphs
            than there are chunks (which would leave a training chunk empty).
    """
    num_chunks = training_config.num_evaluations_per_epoch
    # Without this check, range(0) would silently produce no loaders at all and training would do nothing.
    if num_chunks < 1:
        raise ValueError(f"num_evaluations_per_epoch must be at least 1, got {num_chunks}.")

    dataset = Dataset.load(training_config.dataset_name)
    print("Dataset:", dataset.config.name)

    # Strided slicing distributes the graphs round-robin, so chunk sizes differ by at most one.
    train_graph_chunks = [dataset.train_graphs[start::num_chunks] for start in range(num_chunks)]
    val_graph_chunks = [dataset.val_graphs[start::num_chunks] for start in range(num_chunks)]

    # A shuffling DataLoader over an empty chunk would fail inside torch's RandomSampler with an unrelated
    # message, so report the actual cause up front.
    if any(len(chunk) == 0 for chunk in train_graph_chunks):
        raise ValueError(
            f"The training set has {len(dataset.train_graphs)} graphs, fewer than "
            f"num_evaluations_per_epoch={num_chunks}; every chunk needs at least one training graph."
        )

    # follow_batch=["edge_index"] asks PyG to also build an `edge_index_batch` assignment vector per batch.
    return [
        (
            DataLoader(
                train_graphs,
                training_config.batch_size,
                follow_batch=["edge_index"],
                shuffle=True,
            ),
            DataLoader(
                val_graphs,
                training_config.batch_size,
                follow_batch=["edge_index"],
            ),
        )
        for train_graphs, val_graphs in zip(train_graph_chunks, val_graph_chunks)
    ]


def model(training_config: TrainingConfig) -> tuple[Module, torch.device]:
    """
    Builds the model described by `training_config.model_config` and moves it onto the configured device.

    If the device is CUDA and more than one GPU is visible, the model is wrapped in `DataParallel` before
    being moved.

    Returns:
        A `(model, device)` pair, where `device` is the `torch.device` built from `training_config.device`.
    """
    target_device = torch.device(training_config.device)
    print("Device:", target_device)

    network = training_config.model_config.instantiate_model()

    # NOTE(review): DataParallel defaults to all visible GPUs and expects the module on device_ids[0]
    # (cuda:0); a configured device such as "cuda:1" may therefore fail at forward time — confirm.
    use_data_parallel = target_device.type == "cuda" and torch.cuda.device_count() > 1
    if use_data_parallel:
        network = DataParallel(network)
        print("GPUs available:", torch.cuda.device_count())

    network.to(target_device)
    print("Model:", training_config.model_config.MODEL_NAME)

    return network, target_device


def loss_function(training_config: TrainingConfig) -> Module:
    """
    Initialises the loss function selected by `training_config.loss_function`.

    Valid values are "supervised_binary_cross_entropy", "supervised_imle", "self_supervised_imle",
    "direct_gradient", "tsp_direct_gradient" and "reinforce". The two IMLE variants additionally require
    `training_config.imle_karger_config` to be set and receive `training_config.device`.

    Raises:
        ValueError: if `loss_function` is not one of the values above, or if an IMLE loss is requested while
            `imle_karger_config` is None.
    """
    name = training_config.loss_function
    print("Loss function:", name)

    if name == "supervised_binary_cross_entropy":
        return BCEWithLogitsLossGraphTarget(reduction="sum")
    if name in ("supervised_imle", "self_supervised_imle"):
        # Explicit check instead of `assert`: assertions are stripped when Python runs with -O.
        if training_config.imle_karger_config is None:
            raise ValueError(f'loss_function "{name}" requires imle_karger_config to be set.')
        mode = "supervised" if name == "supervised_imle" else "self_supervised"
        return IMLEKargerLoss(training_config.imle_karger_config, mode, training_config.device)
    if name == "direct_gradient":
        return KargerDirectGradientLoss()
    if name == "tsp_direct_gradient":
        return TSPDirectGradientLoss()
    if name == "reinforce":
        return ReinforceLoss()

    raise ValueError(
        f'loss_function "{name}" is not a valid loss function. '
        "See the documentation of TrainingConfig for valid options."
    )


def log_dir(training_config: TrainingConfig) -> Path:
    """
    Creates a fresh log directory for a training run, saves the training config there, and returns its path.

    The directory is `LOGS_DIR / "<timestamp>--<dataset_name>--<loss_function>-loss"`. The config is dumped
    as YAML to `_training-config.yml` inside it.

    Returns:
        The path of the newly created log directory (a path object, not a `str`).

    Raises:
        FileExistsError: if a directory with the same name already exists (e.g. if `get_timestamp()` returned
            the same value for two runs); existing logs are never overwritten.
    """
    # create log directory
    log_dir_name = f"{get_timestamp()}--{training_config.dataset_name}--{training_config.loss_function}-loss"
    # LOGS_DIR is presumably a pathlib.Path (it is joined with `/`), so the result is a Path as well.
    log_dir_path = LOGS_DIR / log_dir_name
    os.makedirs(log_dir_path)  # no exist_ok: refuse to write into an existing run's directory

    # save training config; explicit encoding so the file does not depend on the platform's default
    with open(log_dir_path / "_training-config.yml", "w", encoding="utf-8") as yml_file:
        YAML().dump(training_config, yml_file)

    print("Log directory name:", log_dir_name)

    return log_dir_path
