import copy
import os
import shutil
import time
from datetime import datetime
import pickle

import torch
from ray import tune
import ray
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING, TRANSFORM_MAPPING
from src.evaluation.eval_measurements import Measurements
from src.evaluation.eval_scores import Scorer
from src.logger.cross_val_logger import ManyFoldLogger
from src.logger.logger import Logger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING
from src.training.trainer import Trainer
from src.evaluation.evaluator import Evaluator
from src.utils.path_io import get_path_up_to
from src.utils.seed import set_seed

# Base directory for all data/output/run paths in this module, derived from this
# file's absolute path via get_path_up_to(..., "src"). Presumably the repository
# root (the directory that contains "src") — confirm against get_path_up_to,
# whose exact semantics are defined elsewhere.
REPO_ROOT = get_path_up_to(os.path.abspath(__file__), "src")


class TuneObjective:
    """Ray Tune trainable that trains and evaluates a model on a single fold.

    One instance is passed to Ray Tune, which calls it once per trial with a
    ``config`` whose ``'init'`` entry is the fold index. Each call builds the
    dataset, trains a model on that fold's train/validation split, evaluates it
    on the test set and pickles the fold's ``Logger`` to
    ``<logger_save_path>/fold_logger_<i>.pkl`` so the driver can aggregate it.
    """

    def __init__(self, config: dict, run_name: str, device: torch.device, start_time: str):
        """
        Args:
            config: Full experiment configuration (``dataset_parameters``,
                ``model_parameters``, ``data_split_parameters``,
                ``training_parameters``, ``evaluation_parameters``). This
                class's own code does not modify it in place.
            run_name: Base name used for loggers, model wrappers and output files.
            device: Torch device used for seeding, training and evaluation.
            start_time: Timestamp string shared by the per-fold loggers.
        """
        self.original_config = config
        self.run_name = run_name
        self.device = device
        self.logger_save_path = os.path.join(REPO_ROOT, 'data')
        self.start_time = start_time

    @staticmethod
    def _fold_split_indices(index_dict: dict):
        """Return ``(train_indices, val_indices)`` for one fold.

        Top-level ``'train'`` / ``'validation'`` entries take precedence; each
        missing one falls back to the first entry of ``'model_selection'``.
        """
        if 'train' in index_dict:
            train_indices = index_dict['train']
        else:
            train_indices = index_dict['model_selection'][0]['train']
        if 'validation' in index_dict:
            val_indices = index_dict['validation']
        else:
            val_indices = index_dict['model_selection'][0]['validation']
        return train_indices, val_indices

    def __call__(self, config: dict, *args, **kwargs):
        """Train and evaluate the model for fold ``config['init']``.

        Args:
            config: Ray Tune trial config; only ``'init'`` (fold index) is read.

        Side effects:
            Deletes and recreates a per-trial dataset directory under
            ``<REPO_ROOT>/data/graphs``, ensures the eval-results directory
            exists and writes the pickled fold logger to ``logger_save_path``.
        """
        i = config['init']
        logger = Logger(name=f"{self.run_name}_fold_{i}",
                        current_time_string=self.start_time)

        # Offset the seed by the fold index so each fold is reproducible yet distinct.
        set_seed(torch.initial_seed() + i, self.device)

        # --- Dataset construction -------------------------------------------
        # Deep copy: the pops below must not alter the shared configuration.
        dataset_parameters = copy.deepcopy(self.original_config["dataset_parameters"])
        if 'pre_transform' in dataset_parameters:
            pre_t_name = dataset_parameters['pre_transform'].pop('pre_t_class')
            pre_transform = PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters['name'],
                                                             **dataset_parameters.pop('pre_transform'))
        else:
            pre_transform = None

        # The transform itself is built later, once the training loader exists.
        transform_params = dataset_parameters.pop('transform', None)

        dataset_name = dataset_parameters.pop("class_name")
        process_dataset = dataset_parameters.pop("process")

        model_parameters = copy.deepcopy(self.original_config["model_parameters"])
        model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

        # Each trial processes its dataset into its own directory.
        base_root_path = os.path.join(REPO_ROOT, "data", "graphs")
        trial_id = tune.get_context().get_trial_id()
        # NOTE(review): only the last two characters of the trial ID are used, so
        # trials whose IDs share that suffix map to the same directory, and the
        # rmtree below could delete a concurrently running trial's data. This is
        # fine only while the number of folds stays small — confirm.
        dataset_subdir = f"ray_dir_{trial_id[-2:]}"
        root_path = os.path.join(base_root_path, dataset_subdir)

        # Start from a clean directory if an earlier run left one behind.
        if os.path.exists(root_path):
            shutil.rmtree(root_path)

        dataset = DATASET_MAPPING[dataset_name](root=root_path,
                                                pre_transform=pre_transform,
                                                **dataset_parameters)

        if process_dataset:
            preprocessing_start = time.time()
            dataset.process()
            preprocessing_end = time.time()
            logger.log_performance((preprocessing_end - preprocessing_start), 0, "preprocessing_time")

        # --- Evaluation helpers ---------------------------------------------
        evaluation_parameters = self.original_config['evaluation_parameters']
        scorer = Scorer(**evaluation_parameters["scorer_parameters"])
        if "measure_params" in evaluation_parameters:
            measurements = Measurements(measures=evaluation_parameters["measure_params"])
        else:
            measurements = None

        if 'output_dim' not in model_parameters:
            model_parameters['output_dim'] = dataset.num_classes

        # --- Data splitting -------------------------------------------------
        # Shallow copy before adding 'n_folds' so the shared config is not mutated.
        data_split_parameters = dict(self.original_config["data_split_parameters"])
        data_split_parameters['n_folds'] = evaluation_parameters['n_test_init']
        _, _, test_set, fold_indices = dataset.split_data(**data_split_parameters)
        batch_size = self.original_config["training_parameters"]['batch_size']
        test_loader = DataLoader(test_set, batch_size=batch_size)

        # Ensure the results directory exists. exist_ok avoids a FileExistsError
        # when concurrent trials race to create it.
        results_file_path = os.path.join(REPO_ROOT, 'data', 'output', 'eval_results', f'results_{self.run_name}.csv')
        os.makedirs(os.path.dirname(results_file_path), exist_ok=True)

        # --- Loaders for the current fold -----------------------------------
        dataset.prepare_fold(i)
        train_indices, val_indices = self._fold_split_indices(fold_indices[i])
        train_loader = DataLoader(Subset(dataset, train_indices), batch_size=batch_size)
        val_loader = DataLoader(Subset(dataset, val_indices), batch_size=batch_size)

        # The transform is built from the training loader and then attached to the dataset.
        if transform_params is not None:
            transform = TRANSFORM_MAPPING[transform_params["t_class"]](data_loader=train_loader)
            dataset.transform = transform

        model_wrapper = model_wrapper_class(run_name=f"{self.run_name}_init_{i}", **model_parameters)
        model = model_wrapper.create_model()
        logger.write_model(model)

        print("[Evaluation]: Train model on train set")

        trainer = Trainer(
            training_loader=train_loader,
            validation_loader=val_loader,
            model_wrapper=model_wrapper,
            device=self.device,
            logger=Logger(name=f"{self.run_name}_init_{i}",
                          current_time_string=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")),
            scorer=scorer,
            **self.original_config["training_parameters"]
        )
        trainer.start_training()

        print("[Evaluation]: Training finished, run inference on test set.")

        evaluator = Evaluator(test_loader=test_loader,
                              model_wrapper=model_wrapper,
                              device=self.device,
                              logger=logger,
                              scorer=scorer,
                              measurements=measurements,
                              train_loader=train_loader,
                              val_loader=val_loader)

        evaluator.evaluate()

        logger.close_sr()

        # Hand the fold logger to the driver process via a pickle file on disk.
        save_path = os.path.join(self.logger_save_path, f'fold_logger_{i}.pkl')
        with open(save_path, "wb") as file:
            pickle.dump(logger, file)


def evaluation_ray_tune(config: dict,
                        run_name: str,
                        resources: dict,
                        max_concurrent_trails: int,
                        device: torch.device
                        ) -> None:
    start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    logger = ManyFoldLogger(name=run_name, current_time_string=start_time)
    logger.write_dict(config)

    ray.init(configure_logging=False)

    search_space = {"init": list(range(config['evaluation_parameters']['n_test_init']))}

    vram = torch.cuda.get_device_properties(0).total_memory / 1000000
    # gpu_share = resources['gpu_share'] if 'gpu_share' in resources else 1.0
    min_mb = resources['min_mb'] if 'min_mb' in resources else 0
    resources = {'gpu': round(min(1, (max(resources['max_mb'] * 1.2 + 1000, min_mb) / vram)), 4)}
    # max_concurrent = gpu_share // resources['gpu']

    # print(f"res gpu: {resources['gpu']}")
    # print(f"SHARE: {gpu_share}")
    # print(f"max curr: {max_concurrent}")

    tuner = tune.Tuner(
        trainable=tune.with_resources(
            tune.with_parameters(TuneObjective(config, run_name, device, start_time=start_time)),
            resources=resources
        ),
        tune_config=tune.TuneConfig(
            metric="criterion",
            mode="min",
            num_samples=1,
            max_concurrent_trials=max_concurrent_trails
        ),
        run_config=tune.RunConfig(
            failure_config=tune.FailureConfig(fail_fast=True),
            storage_path=os.path.join(REPO_ROOT, "runs")
        ),
        param_space={key: tune.grid_search(value) for key, value in search_space.items()},
    )

    results = tuner.fit()

    for i in search_space['init']:
        path = os.path.join(REPO_ROOT, 'data', f'fold_logger_{i}.pkl')
        with open(path, 'rb') as file:
            fold_logger = pickle.load(file)

        fold_logger.close()

        logger.add_fold(fold_logger)
        os.remove(path)

    results_file_path = os.path.join(REPO_ROOT, 'data', 'output', 'eval_results', f'results_{run_name}.csv')
    logger.save_test_scores(results_file_path)
