
import copy
import os
import shutil
import time
from typing import Final

import ray
import torch
from ray import tune
from torch_geometric.data import DataLoader

from pipelines.grid_search import ResourceCalculator
from src.pipelines.evaluation import evaluation
from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING, TRANSFORM_MAPPING
from src.evaluation.eval_scores import Scorer
from src.logger.multi_instance_logger import MultiInstanceLogger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING
from src.utils.config_utils import replace_value_in_config
from src.utils.path_io import get_path_up_to
from src.training.trainer import Trainer
from src.training.trainer import Trainer

# Number of outer test folds: errica_cross_testing runs one grid search per fold, and the
# objective passes this value to dataset.split_data as ``n_folds``.
TEST_FOLDS: Final[int] = 10
# Written to evaluation_parameters["n_test_init"] before each fold's best configuration is evaluated.
TEST_INITIALISATION: Final[int] = 3

# Path returned by get_path_up_to for this file and the "src" marker (presumably the directory that
# contains the ``src`` package — confirm); the Ray Tune "runs" storage path is built under it.
REPO_ROOT = get_path_up_to(os.path.abspath(__file__), "src")

class EricaTuneObjective():
    """Ray Tune trainable that trains one hyper-parameter configuration on one outer test fold.

    ``errica_cross_testing`` creates one instance per outer fold and hands it to ``tune.Tuner``;
    Ray Tune then calls the instance once per grid-search trial with that trial's sampled values.
    """

    def __init__(self, config: dict,
                 run_name: str,
                 device: torch.device,
                 multi_instance_logger: MultiInstanceLogger,
                 test_fold: int = 0):
        """
        Args:
            config (dict): Full base configuration; each trial merges its sampled values into a copy of it.
            run_name (str): Name of the run, forwarded to the model wrapper.
            device (torch.device): Device the Trainer runs on.
            multi_instance_logger (MultiInstanceLogger): Logger factory; ``next_logger()`` is called once
                per trial and the final results are collected/saved after training.
            test_fold (int): Index of the outer fold passed to ``dataset.prepare_fold``.
        """
        self.original_config = config
        self.run_name = run_name
        self.device = device
        self.multi_instance_logger = multi_instance_logger
        self.fold_index = test_fold

    @staticmethod
    def _build_pre_transform(dataset_parameters: dict):
        """Pop the optional 'pre_transform' section from ``dataset_parameters`` and instantiate it (None if absent)."""
        if 'pre_transform' not in dataset_parameters:
            return None
        pre_t_params = dataset_parameters.pop('pre_transform')
        pre_t_name = pre_t_params.pop('pre_t_class')
        return PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters['name'], **pre_t_params)

    @staticmethod
    def _fresh_trial_root(trial_id: str) -> str:
        """Return the dataset root for ``trial_id`` under <REPO_ROOT>/data/graphs, deleting any leftover contents."""
        root_path = os.path.join(REPO_ROOT, "data", "graphs", f"ray_dir_{trial_id}")
        # A rerun with the same trial ID must not reuse stale processed data.
        if os.path.exists(root_path):
            shutil.rmtree(root_path)
        return root_path

    def __call__(self, config: dict):
        """Train one trial on this objective's outer fold and report its ``criterion`` to Ray Tune.

        Args:
            config (dict): Values sampled by Ray Tune for this trial (the search-space keys); they are
                merged into a deep copy of the base configuration.
        """
        # Merge into a deep copy: the pops below mutate nested sections (e.g. 'pre_transform'), which
        # would otherwise corrupt self.original_config if replace_value_in_config returns an alias.
        config = replace_value_in_config(config=copy.deepcopy(self.original_config), new_values=config)

        # Create logger instance
        logger = self.multi_instance_logger.next_logger()
        logger.write_dict(config)

        # Dataset parameters (pre-transform, optional transform, dataset class, processing flag)
        dataset_parameters = config.pop("dataset_parameters")
        pre_transform = self._build_pre_transform(dataset_parameters)
        transform_params = dataset_parameters.pop('transform', None)
        dataset_name = dataset_parameters.pop("class_name")
        process_dataset = dataset_parameters.pop("process")

        # Model parameters
        model_parameters = config.pop("model_parameters")
        model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

        # Use the full trial ID: a suffix such as trial_id[-2:] collides (e.g. trials 00005 and 00105),
        # and the rmtree in _fresh_trial_root would then wipe the data of a concurrently running trial.
        trial_id = tune.get_context().get_trial_id()
        root_path = self._fresh_trial_root(trial_id)

        try:
            dataset = DATASET_MAPPING[dataset_name](root=root_path, pre_transform=pre_transform, **dataset_parameters)

            if process_dataset:
                start_time = time.perf_counter()
                dataset.process()
                logger.log_performance(time.perf_counter() - start_time, 0, "preprocessing_time")

            # The scorer section is optional; Scorer(**None) would raise TypeError, so default to no kwargs.
            scorer_params = config['evaluation_parameters'].pop("scorer_parameters", None) or {}
            scorer = Scorer(**scorer_params)

            if 'output_dim' not in model_parameters:
                model_parameters['output_dim'] = dataset.num_classes

            # Split into TEST_FOLDS outer folds and take the train/validation sets of this objective's fold;
            # the test split is not used during model selection.
            split_parameters = config.pop("data_split_parameters")
            split_parameters["n_folds"] = TEST_FOLDS
            dataset.split_data(**split_parameters)
            train_set, val_set, _ = dataset.prepare_fold(self.fold_index)

            # Create torch data loaders
            batch_size = config["training_parameters"].pop('batch_size')
            train_loader = DataLoader(train_set, batch_size=batch_size)
            validation_loader = DataLoader(val_set, batch_size=batch_size)

            if transform_params is not None:
                # NOTE(review): any keys left in transform_params besides "t_class" are ignored, and the
                # transform is attached to `dataset` only after prepare_fold(); if the fold subsets are
                # copies of the dataset (as PyG's index_select makes), they will not pick it up — confirm.
                transform = TRANSFORM_MAPPING[transform_params.pop("t_class")](data_loader=train_loader)
                dataset.transform = transform

            model_wrapper = model_wrapper_class(run_name=self.run_name, **model_parameters)
            model = model_wrapper.create_model()
            logger.write_model(model)

            print("##################### Device ####################")
            print(self.device)

            fold_logger = logger.new_fold()

            trainer = Trainer(
                training_loader=train_loader,
                validation_loader=validation_loader,
                model_wrapper=model_wrapper,
                device=self.device,
                logger=fold_logger,
                scorer=scorer,
                **config["training_parameters"]
            )

            criterion = trainer.start_training()

            self.multi_instance_logger.collect_final_results()
            self.multi_instance_logger.save_final_results()

            tune.report({"iterations": 1, "criterion": criterion})
        finally:
            # Per-trial directories are unique now, so remove this one to keep disk use bounded.
            shutil.rmtree(root_path, ignore_errors=True)


def errica_cross_testing(config: dict,
                         run_name: str,
                         resources: dict,
                         max_concurrent_trails: int,
                         device: torch.device):
    """
    Run nested cross-testing: one grid search per outer test fold, then evaluate that fold's best config.

    For each of the TEST_FOLDS outer folds, a ``tune.Tuner`` grid-searches ``config["search_space"]``
    with EricaTuneObjective, minimising the reported ``criterion``. The best trial's values are merged
    into the base config, ``evaluation_parameters["n_test_init"]`` is set to TEST_INITIALISATION, and
    the result is passed to ``evaluation``.

    Args:
        config (dict): Base configuration. Must contain "search_space", mapping config keys to the list
            of values to grid-search. The caller's dict is not modified.
        run_name (str): The name of the run, used for logging and saving the model.
        resources (dict): Per-trial resources for Ray Tune. An optional "resource_type" entry selects
            "static" (use the dict as given; the default) or "dynamic" (compute them with
            ResourceCalculator from the remaining entries). The caller's dict is not modified.
        max_concurrent_trails (int): Maximal number of concurrently running trials.
        device (torch.device): The device to run training and evaluation on (CPU or GPU).

    Raises:
        ValueError: If "resource_type" is neither "static" nor "dynamic".

    Returns:
        None. Results are written by the loggers, by Ray Tune (under <REPO_ROOT>/runs) and by ``evaluation``.
    """
    ray.init(configure_logging=False,
             _temp_dir=get_path_up_to(os.path.abspath(__file__), 'generic_gnn_project'))

    # Work on private copies: the caller's dicts must not lose keys, and the nested sections handed to
    # the objective / evaluation() may be mutated by them.
    config = copy.deepcopy(config)
    search_space = config.pop("search_space")

    resources = dict(resources)
    resource_type = resources.pop("resource_type", "static")
    if resource_type == "dynamic":
        resources = ResourceCalculator(**resources).resources_per_trial
    elif resource_type != "static":
        raise ValueError(f"[GRID SEARCH]: Resource type {resource_type} not defined.")

    for fold_index in range(TEST_FOLDS):
        multi_instance_logger = MultiInstanceLogger(run_name)

        objective = EricaTuneObjective(copy.deepcopy(config),
                                       run_name,
                                       device,
                                       multi_instance_logger=multi_instance_logger,
                                       test_fold=fold_index)

        tuner = tune.Tuner(
            trainable=tune.with_resources(
                tune.with_parameters(objective),
                resources=resources
            ),
            tune_config=tune.TuneConfig(
                metric="criterion",
                mode="min",
                num_samples=1,
                max_concurrent_trials=max_concurrent_trails
            ),
            run_config=tune.RunConfig(
                failure_config=tune.FailureConfig(fail_fast=True),
                storage_path=os.path.join(REPO_ROOT, "runs")
            ),
            param_space={key: tune.grid_search(value) for key, value in search_space.items()},
        )
        results = tuner.fit()

        optimal_params = results.get_best_result().config
        print("Best hyperparameters found were: ", optimal_params)

        multi_instance_logger.close()

        # Merge into a fresh copy so whatever evaluation() does to its config cannot leak into later folds.
        optimal_config = replace_value_in_config(copy.deepcopy(config), optimal_params)
        optimal_config["evaluation_parameters"]["n_test_init"] = TEST_INITIALISATION

        # NOTE(review): evaluation() receives neither the fold index nor a fold-specific run_name —
        # confirm it evaluates on the intended test fold and that folds do not overwrite each other's output.
        evaluation(config=optimal_config,
                   run_name=f'{run_name}_evaluation',
                   device=device)
