import copy
import os
import time
from datetime import datetime

import torch
from torch_geometric.loader import DataLoader

from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING, TRANSFORM_MAPPING
from src.evaluation.eval_scores import Scorer
from src.logger.logger import Logger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING
from src.training.trainer import Trainer
from src.utils.path_io import get_path_up_to


def training(config: dict,
             run_name: str,
             device: torch.device) -> None:
    """
    Run the full training pipeline described by ``config``.

    The pipeline logs the configuration, builds the (optionally pre-transformed)
    graph dataset, splits it into train/validation/test subsets, creates the
    model through its wrapper class, trains it with ``Trainer`` and saves it.

    The caller's ``config`` is left untouched: the function works on a deep
    copy, so the same configuration dict can be reused for several runs.

    Args:
        config (dict): Configuration dictionary. Keys read by this function:

            - ``dataset_parameters``: must contain ``class_name`` (key into
              ``DATASET_MAPPING``) and ``process`` (if truthy,
              ``dataset.process()`` is called and its duration logged). May
              contain ``pre_transform`` (dict with ``pre_t_class`` plus
              constructor kwargs; ``name`` is then also required and passed
              as ``graph_name``) and ``transform`` (dict with ``t_class``).
              All remaining keys are forwarded to the dataset class.
            - ``model_parameters``: must contain ``wrapper_class`` (key into
              ``MODEL_WARPPER_MAPPING``). ``output_dim`` defaults to
              ``dataset.num_classes``. Remaining keys go to the wrapper class.
            - ``evaluation_parameters`` (optional): may contain
              ``scorer_parameters``, forwarded as kwargs to ``Scorer``.
            - ``data_split_parameters``: kwargs for ``dataset.split_data``.
            - ``training_parameters``: must contain ``batch_size``; remaining
              keys are forwarded to ``Trainer``.
        run_name (str): Name of the run, used for the logger and passed to
            the model wrapper.
        device (torch.device): Device the trainer runs on.

    Returns:
        None
    """
    # Work on a private copy: the pops below would otherwise strip keys from the
    # caller's (nested) dicts and make a second call with the same config fail.
    config = copy.deepcopy(config)

    logger = Logger(name=run_name, current_time_string=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    logger.write_dict(config)

    # --- Dataset configuration -------------------------------------------
    dataset_parameters = config.pop("dataset_parameters")

    pre_transform_params = dataset_parameters.pop("pre_transform", None)
    if pre_transform_params is not None:
        pre_t_name = pre_transform_params.pop("pre_t_class")
        pre_transform = PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters["name"],
                                                         **pre_transform_params)
    else:
        pre_transform = None

    # The transform itself is built later, once the training loader exists.
    transform_params = dataset_parameters.pop("transform", None)

    dataset_name = dataset_parameters.pop("class_name")
    process_dataset = dataset_parameters.pop("process")

    # --- Model wrapper class ---------------------------------------------
    model_parameters = config.pop("model_parameters")
    model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

    # --- Dataset instance (optionally re-processed) ----------------------
    root_path = os.path.join(get_path_up_to(os.path.abspath(__file__), "src"), "data", "graphs")
    dataset = DATASET_MAPPING[dataset_name](root=root_path, pre_transform=pre_transform, **dataset_parameters)

    if process_dataset:
        start_time = time.perf_counter()
        dataset.process()
        logger.log_performance(time.perf_counter() - start_time, 0, "preprocessing_time")

    # Missing evaluation/scorer sections fall back to an empty kwargs dict
    # instead of crashing on ``Scorer(**None)``.
    evaluation_parameters = config.get("evaluation_parameters") or {}
    scorer_params = evaluation_parameters.pop("scorer_parameters", None) or {}
    scorer = Scorer(**scorer_params)

    if "output_dim" not in model_parameters:
        model_parameters["output_dim"] = dataset.num_classes

    # Train/validation/test split; the test split is not used in this function.
    train_set, val_set, _test_set, _ = dataset.split_data(**config.pop("data_split_parameters"))

    # --- Data loaders ----------------------------------------------------
    batch_size = config["training_parameters"].pop("batch_size")
    # NOTE(review): the training loader is not shuffled (DataLoader default) —
    # confirm this is intended.
    train_loader = DataLoader(train_set, batch_size=batch_size)
    validation_loader = DataLoader(val_set, batch_size=batch_size)

    if transform_params is not None:
        # The transform is constructed from the training loader (presumably to
        # fit statistics on training data only — TODO confirm). Only ``t_class``
        # is read; any other keys in the transform config are currently ignored.
        # NOTE(review): the transform is assigned to ``dataset`` after splitting;
        # whether the split subsets see it depends on what ``split_data``
        # returns — verify.
        transform = TRANSFORM_MAPPING[transform_params.pop("t_class")](data_loader=train_loader)
        dataset.transform = transform

    # --- Model, training and saving --------------------------------------
    model_wrapper = model_wrapper_class(run_name=run_name, **model_parameters)
    model = model_wrapper.create_model()
    logger.write_model(model)

    trainer = Trainer(
        training_loader=train_loader,
        validation_loader=validation_loader,
        model_wrapper=model_wrapper,
        device=device,
        logger=logger,
        scorer=scorer,
        **config["training_parameters"]
    )

    trainer.start_training()
    trainer.save_model()