import copy
import os
import time

import torch
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING, TRANSFORM_MAPPING
from src.evaluation.eval_scores import Scorer
from src.logger.cross_val_logger import ManyFoldLogger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING
from src.training.trainer import Trainer
from src.utils.path_io import get_path_up_to


def cross_validation(config: dict,
                     run_name: str,
                     model_parameters: dict,
                     model_wrapper_class,
                     dataset,
                     device: torch.device,
                     logger: ManyFoldLogger,
                     scorer: Scorer,
                     transform_params: dict=None) -> float:
    """
    Runs a cross-validation training pipeline.

    The dataset is split into folds via ``dataset.split_data``. For every fold,
    ``dataset.prepare_fold`` is called, train/validation loaders are built from the
    fold's indices, a fresh model is created from ``model_wrapper_class`` and trained
    with a ``Trainer``. The per-fold results of ``Trainer.start_training`` are averaged.

    ``config`` is only read, never modified, so the same config can be reused for
    several calls.

    Args:
        config (dict): Configuration dictionary. Must contain ``data_split_parameters``
            (keyword arguments for ``dataset.split_data``) and ``training_parameters``
            (must include ``batch_size``; all other entries are forwarded to ``Trainer``).
        run_name (str): Name of the run; each fold's model wrapper gets
            ``f"{run_name}_fold_{k}"`` (1-based ``k``).
        model_parameters (dict): Keyword arguments for ``model_wrapper_class``.
        model_wrapper_class (class): Wrapper class whose ``create_model()`` builds the model.
        dataset (Dataset): Dataset providing ``split_data`` and ``prepare_fold``.
        device (torch.device): The device to run the training on (CPU or GPU).
        logger (ManyFoldLogger): Logger; ``new_fold()`` is called once per fold and
            ``summarize_folds()`` once at the end.
        scorer (Scorer): Scorer passed to every fold's ``Trainer``.
        transform_params (dict, optional): If given, ``transform_params["t_class"]``
            selects an entry of ``TRANSFORM_MAPPING`` that is constructed from each
            fold's training loader and installed as ``dataset.transform``.

    Returns:
        float: The mean of the per-fold results returned by ``Trainer.start_training``.

    Raises:
        ValueError: If ``dataset.split_data`` produces no folds.
    """
    _, _, _, fold_indices = dataset.split_data(**config["data_split_parameters"])
    if not fold_indices:
        raise ValueError("[CROSS VALIDATION]: dataset.split_data returned no folds")

    # Copy before removing batch_size so the caller's config stays intact.
    training_parameters = dict(config["training_parameters"])
    batch_size = training_parameters.pop('batch_size')

    # Transform installed on the dataset before any fold-specific transform.
    base_transform = getattr(dataset, 'transform', None)

    results = []
    for fold_index, index_dict in fold_indices.items():
        print(f"[CROSS VALIDATION]: Starting training for fold {fold_index + 1}/{len(fold_indices)}")
        fold_logger = logger.new_fold()

        train_indices = index_dict['train']
        val_indices = index_dict['validation']

        dataset.prepare_fold(fold_index)

        # Create torch data loaders for the current fold
        train_set_fold = Subset(dataset, train_indices)
        val_set_fold = Subset(dataset, val_indices)

        train_loader = DataLoader(train_set_fold, batch_size=batch_size)
        validation_loader = DataLoader(val_set_fold, batch_size=batch_size)

        if transform_params is not None:
            # Restore the base transform first: if the transform constructor reads
            # batches from train_loader, the transform fitted on the previous fold
            # must not be applied to the data the new one is built from.
            dataset.transform = base_transform
            transform = TRANSFORM_MAPPING[transform_params["t_class"]](data_loader=train_loader)
            dataset.transform = transform

        model_wrapper = model_wrapper_class(run_name=f"{run_name}_fold_{fold_index + 1}", **model_parameters)
        model = model_wrapper.create_model()
        fold_logger.write_model(model)

        trainer = Trainer(
            training_loader=train_loader,
            validation_loader=validation_loader,
            model_wrapper=model_wrapper,
            device=device,
            logger=fold_logger,
            scorer=scorer,
            **training_parameters
        )

        results.append(trainer.start_training())

    logger.summarize_folds()

    return sum(results) / len(results)


def cross_validation_pipeline(config: dict,
                              run_name: str,
                              device: torch.device) -> float:
    """
    Builds dataset, model wrapper and scorer from ``config`` and runs cross-validation.

    The caller's ``config`` is logged as-is and is not modified; all consumption of
    entries happens on a deep copy.

    Config keys read:
        dataset_parameters: ``class_name`` (key into ``DATASET_MAPPING``), ``process``
            (whether to call ``dataset.process()`` and log its duration), optional
            ``pre_transform`` (``pre_t_class`` selects a ``PRETRANSFORM_MAPPING`` entry,
            remaining entries are its kwargs; it also receives ``graph_name=name``),
            optional ``transform`` (forwarded to ``cross_validation``); every other
            entry is passed to the dataset constructor.
        model_parameters: ``wrapper_class`` (key into ``MODEL_WARPPER_MAPPING``); the
            rest are wrapper kwargs. ``output_dim`` defaults to ``dataset.num_classes``.
        evaluation_parameters: optional ``scorer_parameters`` (kwargs for ``Scorer``).
        data_split_parameters / training_parameters: consumed by ``cross_validation``.

    Args:
        config (dict): Run configuration (see above).
        run_name (str): Name used for the logger and the per-fold model wrappers.
        device (torch.device): The device to run the training on.

    Returns:
        float: The mean per-fold result returned by ``cross_validation``.
    """
    logger = ManyFoldLogger(name=run_name)
    logger.write_dict(config)

    # Entries (including nested ones such as ``pre_t_class``) are popped below; work
    # on a copy so the caller's config survives and can be reused for another run.
    run_config = copy.deepcopy(config)

    # Create graph dataset
    dataset_parameters = run_config.pop("dataset_parameters")

    # Optional pre-transform; ``pre_transform: null`` is treated as "no pre-transform".
    pre_transform_params = dataset_parameters.pop('pre_transform', None)
    if pre_transform_params is not None:
        pre_t_name = pre_transform_params.pop('pre_t_class')
        pre_transform = PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters['name'],
                                                         **pre_transform_params)
    else:
        pre_transform = None

    transform_params = dataset_parameters.pop('transform', None)

    # Extracting dataset parameters from config
    dataset_name = dataset_parameters.pop("class_name")
    process_dataset = dataset_parameters.pop("process")

    # Extracting model parameters from config
    model_parameters = run_config.pop("model_parameters")
    model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

    # Create instance for dataset and process if given in config
    root_path = os.path.join(get_path_up_to(os.path.abspath(__file__), "src"), "data", "graphs")
    dataset = DATASET_MAPPING[dataset_name](root=root_path, pre_transform=pre_transform, **dataset_parameters)

    if process_dataset:
        # perf_counter is monotonic, unlike wall-clock time.time().
        start_time = time.perf_counter()
        dataset.process()
        logger.log_performance(time.perf_counter() - start_time, 0, "preprocessing_time")

    # scorer_parameters is optional; Scorer(**None) would raise TypeError.
    scorer_params = run_config.get('evaluation_parameters', {}).pop("scorer_parameters", None)
    scorer = Scorer(**(scorer_params or {}))

    if 'output_dim' not in model_parameters:
        model_parameters['output_dim'] = dataset.num_classes

    return cross_validation(config=run_config,
                            run_name=run_name,
                            model_parameters=model_parameters,
                            model_wrapper_class=model_wrapper_class,
                            transform_params=transform_params,
                            dataset=dataset,
                            device=device,
                            logger=logger,
                            scorer=scorer)
