import copy
import os
import shutil
import time

import ray
import torch
from filelock import FileLock
from ray import tune

from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING
from src.evaluation.eval_scores import Scorer
from src.logger.multi_instance_logger import MultiInstanceLogger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING

from src.pipelines.cross_validation import cross_validation
from src.utils.config_utils import replace_value_in_config
from src.utils.path_io import get_path_up_to

# Project path resolved once at import time: this file's absolute location is
# handed to the project helper ``get_path_up_to`` together with the marker
# "src" (presumably truncating the path at the "src" component — confirm in
# src/utils/path_io). Used as the base for the Ray Tune storage path below.
_THIS_FILE = os.path.abspath(__file__)
REPO_ROOT = get_path_up_to(_THIS_FILE, "src")

class ResourceCalculator:
    """Estimate the fractional GPU share to request for one Ray Tune trial.

    The estimate is linear in the trial's ``hidden_dim`` and
    ``n_message_passings`` hyperparameters::

        required = hidden_dim * hidden_dim_mb
                   + layers * n_message_passing_mb
                   + layers * hidden_dim * combined_mb
                   + base_mb

    and the requested share is ``required / vram``, capped at 1 (one full GPU).

    Args:
        hidden_dim_mb: Cost per unit of ``hidden_dim``.
        n_message_passing_mb: Cost per message-passing layer.
        combined_mb: Cost per unit of ``layers * hidden_dim``.
        base_mb: Constant cost added to every estimate.
        vram: Capacity the estimate is divided by; must be in the same unit
            as the cost coefficients (MB, judging by the ``_mb`` names).
    """

    def __init__(
        self,
        hidden_dim_mb=20,
        n_message_passing_mb=20,
        combined_mb=20,
        base_mb=20,
        vram=49000,
    ):
        # Per-term cost coefficients of the linear memory model.
        self.hidden_dim_mb = hidden_dim_mb
        self.n_message_passing_mb = n_message_passing_mb
        self.combined_mb = combined_mb
        self.base_mb = base_mb
        # Capacity used to turn the estimate into a fraction of one GPU.
        self.vram = vram

    def resources_per_trial(self, config):
        """Return the Ray resource request for the trial described by ``config``.

        A missing ``hidden_dim`` defaults to 512 and a missing
        ``n_message_passings`` to 64; a notice is printed for each default used.

        Args:
            config: Hyperparameters Ray Tune sampled for the trial.

        Returns:
            dict: ``{'gpu': share}`` where ``share = min(1, estimate / vram)``.
        """
        if "hidden_dim" not in config:
            print("[GRID SEARCH: 'hidden_dim' not in search space. Assume dim 512]")
            dim = 512
        else:
            dim = config["hidden_dim"]

        if "n_message_passings" not in config:
            print("[GRID SEARCH: 'n_message_passings' not in search space. Assume 64 layers]")
            n_layers = 64
        else:
            n_layers = config["n_message_passings"]

        estimate = (
            dim * self.hidden_dim_mb
            + n_layers * self.n_message_passing_mb
            + n_layers * dim * self.combined_mb
            + self.base_mb
        )
        # Never request more than one whole GPU for a single trial.
        share = min(1, estimate / self.vram)
        return {"gpu": share}


class TuneObjective:
    """Callable Ray Tune trainable that runs one cross-validation trial.

    Each call merges the trial's sampled hyperparameters into the base
    configuration, builds the pre-transform, dataset, model wrapper class and
    scorer described by that configuration, runs ``cross_validation`` and
    reports the returned criterion to Ray Tune.

    Args:
        config: Base experiment configuration (search space already removed).
        run_name: Name of the run, forwarded to ``cross_validation``.
        device: Device forwarded to ``cross_validation``.
        multi_instance_logger: Logger factory; one logger is taken per call.
    """

    def __init__(self, config: dict, run_name: str, device: torch.device, multi_instance_logger: MultiInstanceLogger):
        self.original_config = config
        self.run_name = run_name
        self.device = device
        self.multi_instance_logger = multi_instance_logger

    def __call__(self, config: dict):
        """Run one trial for the hyperparameters in ``config`` and report it.

        Args:
            config: Hyperparameter values Ray Tune sampled for this trial; they
                override the matching entries of the base configuration.
        """
        # Merge into a deep copy: the code below pops keys out of the merged
        # (nested) dicts, and those pops must never leak back into the shared
        # base config if this object is called more than once in a process.
        config = replace_value_in_config(config=copy.deepcopy(self.original_config), new_values=config)

        # Create logger instance for this trial
        logger = self.multi_instance_logger.next_logger()
        logger.write_dict(config)

        # --- Dataset transforms ---
        dataset_parameters = config.pop("dataset_parameters")
        if "pre_transform" in dataset_parameters:
            pre_transform_params = dataset_parameters.pop("pre_transform")
            pre_t_name = pre_transform_params.pop("pre_t_class")
            pre_transform = PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters["name"],
                                                             **pre_transform_params)
        else:
            pre_transform = None
        transform_params = dataset_parameters.pop("transform", None)

        # --- Dataset / model settings ---
        dataset_name = dataset_parameters.pop("class_name")
        process_dataset = dataset_parameters.pop("process")

        model_parameters = config.pop("model_parameters")
        model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

        # Per-trial scratch subdirectory. The full trial ID is used: a short
        # suffix such as trial_id[-2:] can collide between trials whose IDs end
        # in the same characters (e.g. ..._00012 and ..._00112), letting
        # concurrent trials share and delete each other's scratch directory.
        trial_id = tune.get_context().get_trial_id()
        dataset_subdir = f"ray_dir_{trial_id}"
        root_path = os.path.join(REPO_ROOT, "data", "graphs")

        dataset = DATASET_MAPPING[dataset_name](
            root=root_path,
            pre_transform=pre_transform,
            multithreading_subdir=dataset_subdir,
            **dataset_parameters
        )

        try:
            if process_dataset:
                start_time = time.perf_counter()
                dataset.process()
                logger.log_performance(time.perf_counter() - start_time, 0, "preprocessing_time")

            # Fall back to an empty mapping: ``Scorer(**None)`` would raise TypeError.
            scorer_params = config["evaluation_parameters"].pop("scorer_parameters", None)
            scorer = Scorer(**(scorer_params or {}))

            if "output_dim" not in model_parameters:
                model_parameters["output_dim"] = dataset.num_classes

            print("##################### Device ####################")
            print(self.device)

            criterion = cross_validation(config=config,
                                         run_name=self.run_name,
                                         model_parameters=model_parameters,
                                         model_wrapper_class=model_wrapper_class,
                                         dataset=dataset,
                                         device=self.device,
                                         logger=logger,
                                         scorer=scorer,
                                         transform_params=transform_params)
        finally:
            # Remove the per-trial scratch directory even when preprocessing
            # or cross-validation raised.
            dataset.clear_multithread_subdir()

        self.multi_instance_logger.collect_final_results()
        self.multi_instance_logger.save_final_results()

        tune.report({"iterations": 1, "criterion": criterion})


def grid_search(config: dict,
                run_name: str,
                resources: dict,
                max_concurrent_trails: int,
                device: torch.device
                ):
    """
    Perform a Ray Tune grid search over the configured hyperparameters.

    Every combination of the values listed in ``config["search_space"]`` is run
    once as a trial of ``TuneObjective``. The trial with the minimal reported
    ``criterion`` is treated as best.

    Args:
        config (dict): Experiment configuration. Must contain ``"search_space"``,
            a mapping from hyperparameter name to the list of values to try. The
            remaining entries form each trial's base configuration. The
            caller's dict is not modified.
        run_name (str): The name of the run, used for logging and saving the model.
        resources (dict): Per-trial resource request. With
            ``"resource_type": "static"`` (the default), the remaining entries
            are passed to Ray unchanged. With ``"dynamic"``, they are keyword
            arguments for ``ResourceCalculator``, and the GPU share is computed
            per trial. In that mode, ``"vram"`` is read from CUDA device 0 (in
            MB) only if it is not given. The caller's dict is not modified.
        max_concurrent_trails (int): Maximal number of concurrent trials.
        device (torch.device): The device to run the training on (CPU or GPU).

    Returns:
        ray.tune.ResultGrid: The results of all trials.

    Raises:
        ValueError: If ``resource_type`` is neither ``"static"`` nor ``"dynamic"``.
    """
    # Shallow copies so the top-level pops below do not mutate the caller's dicts.
    config = dict(config)
    resources = dict(resources)

    search_space = config.pop("search_space")

    # Validate/resolve resources before starting Ray or the logger, so a bad
    # resource type fails fast without leaving anything running.
    resource_type = resources.pop("resource_type", "static")
    if resource_type == "dynamic":
        if "vram" not in resources:
            resources["vram"] = torch.cuda.get_device_properties(0).total_memory / 1000000
        resources = ResourceCalculator(**resources).resources_per_trial
    elif resource_type != "static":
        raise ValueError(f"[GRID SEARCH]: Resource type {resource_type} not defined.")

    # Only start (and later stop) Ray if nobody else has already; a plain second
    # ray.init() in the same process would raise.
    started_ray = not ray.is_initialized()
    if started_ray:
        ray.init(configure_logging=False)

    multi_instance_logger = MultiInstanceLogger(run_name)
    try:
        tuner = tune.Tuner(
            trainable=tune.with_resources(
                tune.with_parameters(TuneObjective(config, run_name, device, multi_instance_logger=multi_instance_logger)),
                resources=resources
            ),
            tune_config=tune.TuneConfig(
                metric="criterion",
                mode="min",
                num_samples=1,
                max_concurrent_trials=max_concurrent_trails
            ),
            run_config=tune.RunConfig(
                failure_config=tune.FailureConfig(fail_fast=False),
                storage_path=os.path.join(REPO_ROOT, "runs")
            ),
            param_space={key: tune.grid_search(value) for key, value in search_space.items()},
        )
        results = tuner.fit()

        print("Best hyperparameters found were: ", results.get_best_result().config)
    finally:
        # Release the logger and Ray even if fitting or result selection failed.
        multi_instance_logger.close()
        if started_ray:
            ray.shutdown()

    return results
