import os
import time
from datetime import datetime

import torch
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING, TRANSFORM_MAPPING
from src.evaluation.eval_measurements import Measurements
from src.evaluation.eval_scores import Scorer
from src.logger.cross_val_logger import ManyFoldLogger
from src.logger.logger import Logger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING
from src.training.trainer import Trainer
from src.evaluation.evaluator import Evaluator
from src.utils.path_io import get_path_up_to
from src.utils.seed import seed_step

# Base directory under which this module places its "data/..." paths
# (evaluation results are written to <REPO_ROOT>/data/output/eval_results).
# NOTE(review): presumably the repository root, i.e. the part of this file's
# absolute path leading up to the "src" directory -- confirm against
# get_path_up_to.
REPO_ROOT = get_path_up_to(os.path.abspath(__file__), "src")


def evaluation_cross_val(config: dict,
                         run_name: str,
                         device: torch.device) -> None:
    """Train and evaluate a freshly initialised model once per test initialisation.

    The dataset is built (and optionally processed) once and split with
    ``n_folds`` set to ``evaluation_parameters['n_test_init']``. For every
    initialisation ``i`` the function prepares fold ``i``, trains a new model
    on that fold's train/validation indices, evaluates it on the shared test
    loader and steps the seed. Finally the collected test scores are saved to
    ``<REPO_ROOT>/data/output/eval_results/results_<run_name>.csv``.

    Note:
        ``config`` is modified in place: ``dataset_parameters``,
        ``model_parameters`` and ``data_split_parameters`` are popped from it,
        as are several nested keys (``scorer_parameters``, ``measure_params``,
        ``n_test_init``, ``batch_size``). Pass a copy if the caller still
        needs the full configuration afterwards.

    Args:
        config: Experiment configuration. Must contain ``dataset_parameters``
            (with ``class_name`` and ``process``; optionally ``pre_transform``
            and ``transform``), ``model_parameters`` (with ``wrapper_class``),
            ``evaluation_parameters`` (with ``n_test_init``; optionally
            ``scorer_parameters`` and ``measure_params``),
            ``data_split_parameters`` and ``training_parameters`` (with
            ``batch_size``).
        run_name: Name used for the loggers, the per-initialisation model
            wrappers and the results file.
        device: Device handed to the trainer, the evaluator and ``seed_step``.

    Raises:
        KeyError: If a required configuration key is missing, or a fold's
            index dict provides neither top-level ``train``/``validation``
            entries nor a ``model_selection`` list.
    """
    logger = ManyFoldLogger(name=run_name)
    logger.write_dict(config)

    # --- Dataset configuration -------------------------------------------
    dataset_parameters = config.pop("dataset_parameters")

    # Optional pre-transform: the class is selected via 'pre_t_class', the
    # remaining entries are forwarded as keyword arguments.
    pre_transform = None
    if 'pre_transform' in dataset_parameters:
        pre_transform_params = dataset_parameters.pop('pre_transform')
        pre_t_name = pre_transform_params.pop('pre_t_class')
        pre_transform = PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters['name'],
                                                         **pre_transform_params)

    # Optional transform; it is only instantiated inside the loop because it
    # is constructed from the training loader of the current fold.
    transform_params = dataset_parameters.pop('transform', None)

    dataset_name = dataset_parameters.pop("class_name")
    process_dataset = dataset_parameters.pop("process")

    # --- Model configuration ---------------------------------------------
    model_parameters = config.pop("model_parameters")
    model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

    # --- Dataset creation (and optional, timed processing) ---------------
    root_path = os.path.join(REPO_ROOT, "data", "graphs")
    dataset = DATASET_MAPPING[dataset_name](root=root_path, pre_transform=pre_transform, **dataset_parameters)

    if process_dataset:
        start_time = time.time()
        dataset.process()
        end_time = time.time()
        logger.log_performance((end_time - start_time), 0, "preprocessing_time")

    # --- Scoring and measurements ----------------------------------------
    evaluation_parameters = config['evaluation_parameters']

    # A missing 'scorer_parameters' entry must not crash the ** unpacking.
    scorer_params = evaluation_parameters.pop("scorer_parameters", None) or {}
    scorer = Scorer(**scorer_params)

    if "measure_params" in evaluation_parameters:
        measurements = Measurements(measures=evaluation_parameters.pop("measure_params"))
    else:
        measurements = None

    if 'output_dim' not in model_parameters:
        model_parameters['output_dim'] = dataset.num_classes

    # --- Data splitting ---------------------------------------------------
    n_test_initialisations = evaluation_parameters.pop('n_test_init')
    data_split_parameters = config.pop("data_split_parameters")
    # One fold per test initialisation.
    data_split_parameters['n_folds'] = n_test_initialisations
    _, _, test_set, fold_indices = dataset.split_data(**data_split_parameters)
    batch_size = config["training_parameters"].pop('batch_size')
    test_loader = DataLoader(test_set, batch_size=batch_size)

    # --- Results file -----------------------------------------------------
    results_file_path = os.path.join(REPO_ROOT, 'data', 'output', 'eval_results', f'results_{run_name}.csv')
    os.makedirs(os.path.dirname(results_file_path), exist_ok=True)

    for i in range(n_test_initialisations):

        # Create torch data loaders for the current fold
        dataset.prepare_fold(i)
        index_dict = fold_indices[i]
        # Prefer top-level indices; otherwise fall back to the first
        # model-selection split of this fold.
        if 'train' in index_dict:
            train_indices = index_dict['train']
        else:
            train_indices = index_dict['model_selection'][0]['train']
        if 'validation' in index_dict:
            val_indices = index_dict['validation']
        else:
            val_indices = index_dict['model_selection'][0]['validation']

        train_set_fold = Subset(dataset, train_indices)
        val_set_fold = Subset(dataset, val_indices)
        train_loader = DataLoader(train_set_fold, batch_size=batch_size)
        val_loader = DataLoader(val_set_fold, batch_size=batch_size)

        if transform_params is not None:
            # The transform is rebuilt from this fold's training loader and
            # attached to the dataset shared by the fold subsets.
            transform = TRANSFORM_MAPPING[transform_params["t_class"]](data_loader=train_loader)
            dataset.transform = transform

        fold_logger = logger.new_fold()

        model_wrapper = model_wrapper_class(run_name=f"{run_name}_init_{i}", **model_parameters)
        model = model_wrapper.create_model()
        fold_logger.write_model(model)

        print("[Evaluation]: Train model on train set")

        trainer = Trainer(
            training_loader=train_loader,
            validation_loader=val_loader,
            model_wrapper=model_wrapper,
            device=device,
            logger=Logger(name=f"{run_name}_init_{i}",
                          current_time_string=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")),
            scorer=scorer,
            **config["training_parameters"]
        )
        trainer.start_training()

        print("[Evaluation]: Training finished, run inference on test set.")

        evaluator = Evaluator(test_loader=test_loader,
                              model_wrapper=model_wrapper,
                              device=device,
                              logger=fold_logger,
                              scorer=scorer,
                              measurements=measurements,
                              train_loader=train_loader,
                              val_loader=val_loader)

        # The return value is not needed here; the fold logger receives the
        # evaluator's output.
        evaluator.evaluate()
        fold_logger.close()

        # Increasing torch seed by one
        seed_step(device)

    logger.save_test_scores(results_file_path)
