import copy
import os
import shutil
import time
from datetime import datetime
import pickle

import torch
import yaml
from ray import tune
import ray
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

from src.pipelines.grid_search import ResourceCalculator
from src.logger.multi_instance_logger import MultiInstanceLogger
from src.datasets.dataset_utils.dataset_constants import DATASET_MAPPING
from src.datasets.preprocessing.mapping import PRETRANSFORM_MAPPING, TRANSFORM_MAPPING
from src.evaluation.eval_measurements import Measurements
from src.evaluation.eval_scores import Scorer
from src.logger.cross_val_logger import ManyFoldLogger
from src.logger.logger import Logger
from src.models.mappings.wrapper_mapping import MODEL_WARPPER_MAPPING
from src.training.trainer import Trainer
from src.evaluation.evaluator import Evaluator
from src.utils.path_io import get_path_up_to
from src.utils.seed import set_seed

# Base directory for the paths built in this module (e.g. REPO_ROOT/data,
# REPO_ROOT/runs). It is derived from this file's absolute path via
# get_path_up_to(..., "src"); presumably this is the repository directory
# containing `src` (confirm against get_path_up_to's semantics).
REPO_ROOT = get_path_up_to(os.path.abspath(__file__), "src")


def get_resources(config: dict, device_index: int = 0) -> dict:
    """Estimate the fractional GPU share a trial needs, from its YAML config.

    Args:
        config: Mapping with a ``'path'`` entry naming a YAML file. That file
            must contain ``resources.max_mb``; the value is compared against
            the device's total memory divided by 1e6, so it is presumably in MB.
        device_index: CUDA device whose total memory is the reference
            (defaults to 0, the previously hard-coded device).

    Returns:
        ``{'gpu': fraction}``. ``fraction`` is ``max_mb * 3.5`` divided by the
        device's total memory in MB, capped at 1 and rounded to 4 decimals.
        NOTE(review): the meaning of the 3.5 factor is not documented here.
    """
    # Keep the caller's ``config`` intact; the YAML is a separate object.
    with open(config['path'], "r", encoding="utf-8") as f:
        trial_config = yaml.safe_load(f)

    vram_mb = torch.cuda.get_device_properties(device_index).total_memory / 1_000_000
    required_mb = trial_config['resources']['max_mb'] * 3.5
    return {'gpu': round(min(1, required_mb / vram_mb), 4)}


class TuneObjective:
    """Ray Tune trainable that trains and evaluates one model on one data fold.

    Each trial receives a ``path`` parameter pointing at a YAML config. The
    config is loaded, the dataset and model are built from it, the model is
    trained on the fold's train split, and it is then evaluated on the test
    split returned by ``dataset.split_data``.
    """

    def __init__(self, config: dict, run_name: str, device: torch.device, multi_instance_logger):
        """Store state shared by all trials.

        Args:
            config: Top-level run configuration. It is stored, but each trial
                reads its own config from the YAML file given in the trial's
                ``path`` parameter.
            run_name: Name of the overall run. ``run_name[:-1]`` names the
                results CSV.
            device: Torch device used for seeding, training and evaluation.
            multi_instance_logger: Logger that hands out one logger per call
                via ``next_logger()``.
        """
        self.original_config = config
        self.run_name = run_name
        self.device = device
        self.multi_instance_logger = multi_instance_logger
        self.logger_save_path = os.path.join(REPO_ROOT, 'data')

    @staticmethod
    def _build_pre_transform(dataset_parameters: dict):
        """Pop the ``'pre_transform'`` spec from *dataset_parameters* and instantiate it.

        Returns ``None`` if no pre-transform is configured. The key is removed
        from *dataset_parameters* so it is not forwarded to the dataset
        constructor.
        """
        if 'pre_transform' not in dataset_parameters:
            return None
        pre_t_params = dataset_parameters.pop('pre_transform')
        pre_t_name = pre_t_params.pop('pre_t_class')
        return PRETRANSFORM_MAPPING[pre_t_name](graph_name=dataset_parameters['name'], **pre_t_params)

    @staticmethod
    def _train_val_indices(index_dict: dict):
        """Return ``(train_indices, val_indices)`` for one fold.

        Top-level ``'train'`` / ``'validation'`` entries take precedence. Each
        missing entry falls back to the first model-selection split.
        """
        if 'train' in index_dict:
            train_indices = index_dict['train']
        else:
            train_indices = index_dict['model_selection'][0]['train']
        if 'validation' in index_dict:
            val_indices = index_dict['validation']
        else:
            val_indices = index_dict['model_selection'][0]['validation']
        return train_indices, val_indices

    def __call__(self, config: dict, *args, **kwargs):
        """Run a single trial: train on one fold, then evaluate on the test set.

        Args:
            config: Ray Tune parameter dict. Only ``config['path']`` is used; it
                names the YAML config for this trial.
            *args, **kwargs: Ignored.
        """
        with open(config['path'], "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        run_name = config['run_name']
        # The fold index is the last character of the run name, so only
        # single-digit fold indices (0-9) are supported.
        i = int(run_name[-1])

        # Per-trial logger and a fresh fold logger for this run.
        logger = self.multi_instance_logger.next_logger()
        logger.write_dict(config)
        fold_logger = logger.new_fold()

        # Offset the seed by the fold index so each fold gets a different seed.
        set_seed(torch.initial_seed() + i, self.device)

        # Dataset parameters are deep-copied because entries are popped below.
        dataset_parameters = copy.deepcopy(config["dataset_parameters"])
        pre_transform = self._build_pre_transform(dataset_parameters)
        transform_params = dataset_parameters.pop('transform', None)
        dataset_name = dataset_parameters.pop("class_name")
        process_dataset = dataset_parameters.pop("process")

        # Model parameters (deep-copied for the same reason).
        model_parameters = copy.deepcopy(config["model_parameters"])
        model_wrapper_class = MODEL_WARPPER_MAPPING[model_parameters.pop("wrapper_class")]

        # Dataset directory, keyed by the trial id.
        trial_id = tune.get_context().get_trial_id()
        # NOTE(review): only the last two characters of the trial id are used.
        # Trials sharing that suffix map to the same directory, and the rmtree
        # below would delete the other trial's data if they overlap in time.
        root_path = os.path.join(REPO_ROOT, "data", "graphs", f"ray_dir_{trial_id[-2:]}")
        if os.path.exists(root_path):
            shutil.rmtree(root_path)
        dataset = DATASET_MAPPING[dataset_name](root=root_path, pre_transform=pre_transform, **dataset_parameters)

        if process_dataset:
            start_time = time.time()
            dataset.process()
            fold_logger.log_performance(time.time() - start_time, 0, "preprocessing_time")

        eval_parameters = config['evaluation_parameters']
        scorer = Scorer(**eval_parameters["scorer_parameters"])
        if "measure_params" in eval_parameters:
            measurements = Measurements(measures=eval_parameters["measure_params"])
        else:
            measurements = None

        if 'output_dim' not in model_parameters:
            model_parameters['output_dim'] = dataset.num_classes

        # The number of folds equals the number of test initialisations.
        data_split_parameters = config["data_split_parameters"]
        data_split_parameters['n_folds'] = eval_parameters['n_test_init']
        _, _, test_set, fold_indices = dataset.split_data(**data_split_parameters)
        batch_size = config["training_parameters"]['batch_size']
        test_loader = DataLoader(test_set, batch_size=batch_size)

        # Results CSV shared by the run. Concurrent trials may create the
        # directory at the same time, so tolerate it already existing.
        results_file_path = os.path.join(REPO_ROOT, 'data', 'output', 'eval_results', f'{self.run_name[:-1]}.csv')
        os.makedirs(os.path.dirname(results_file_path), exist_ok=True)

        # Torch data loaders for the current fold.
        dataset.prepare_fold(i)
        train_indices, val_indices = self._train_val_indices(fold_indices[i])
        train_loader = DataLoader(Subset(dataset, train_indices), batch_size=batch_size)
        val_loader = DataLoader(Subset(dataset, val_indices), batch_size=batch_size)

        if transform_params is not None:
            # The transform class is constructed from the training loader.
            transform = TRANSFORM_MAPPING[transform_params["t_class"]](data_loader=train_loader)
            dataset.transform = transform

        model_wrapper = model_wrapper_class(run_name=f"{run_name}_init_{i}", **model_parameters)
        model = model_wrapper.create_model()
        fold_logger.write_model(model)

        print("[Evaluation]: Train model on train set")

        trainer = Trainer(
            training_loader=train_loader,
            validation_loader=val_loader,
            model_wrapper=model_wrapper,
            device=self.device,
            logger=Logger(name=f"{run_name}_init_{i}",
                          current_time_string=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")),
            scorer=scorer,
            **config["training_parameters"]
        )
        trainer.start_training()

        print("[Evaluation]: Training finished, run inference on test set.")

        evaluator = Evaluator(test_loader=test_loader,
                              model_wrapper=model_wrapper,
                              device=self.device,
                              logger=fold_logger,
                              scorer=scorer,
                              measurements=measurements,
                              train_loader=train_loader,
                              val_loader=val_loader)
        evaluator.evaluate()

        fold_logger.close()
        logger.save_test_scores(results_file_path)

        self.multi_instance_logger.collect_final_results()
        self.multi_instance_logger.save_final_results()


def errica_evaluation_ray_tune(config: dict,
                               run_name: str,
                               resources: dict,
                               max_concurrent_trails: int,
                               device: torch.device
                               ) -> None:
    """Run one Ray Tune trial per config file found in ``config['dir_path']``.

    Each file in that directory becomes one grid-search point: the trial's
    ``path`` parameter, which ``TuneObjective`` loads.

    Args:
        config: Run configuration. ``config['dir_path']`` is joined onto
            ``REPO_ROOT`` (an absolute path is used unchanged).
        run_name: Name of the run. It is passed to ``MultiInstanceLogger`` and
            ``TuneObjective``.
        resources: Per-trial resource spec. An optional ``'resource_type'``
            entry selects ``"static"`` or ``"dynamic"``:
            - ``"static"`` (the default) passes the remaining dict to Ray as-is.
            - ``"dynamic"`` passes the remaining entries to
              ``ResourceCalculator``.
            The caller's dict is not mutated.
        max_concurrent_trails: Maximum number of trials Ray runs concurrently.
        device: Torch device passed to every trial.

    Raises:
        ValueError: If ``resource_type`` is neither ``"static"`` nor ``"dynamic"``.
    """
    multi_instance_logger = MultiInstanceLogger(run_name)

    ray.init(configure_logging=False)

    # List the same resolved directory the paths are joined onto. Listing the
    # raw config value would resolve against the CWD instead of REPO_ROOT.
    # Sorting makes the trial-to-config order deterministic.
    dir_path = os.path.join(REPO_ROOT, config['dir_path'])
    search_space = {"path": [os.path.join(dir_path, path) for path in sorted(os.listdir(dir_path))]}

    # Work on a copy so popping "resource_type" leaves the caller's dict intact.
    resources = dict(resources)
    resource_type = resources.pop("resource_type", "static")
    if resource_type == "dynamic":
        res_calculator = ResourceCalculator(**resources)
        resources = res_calculator.resources_per_trial
    elif resource_type != "static":
        raise ValueError(f"[GRID SEARCH]: Resource type {resource_type} not defined.")

    tuner = tune.Tuner(
        trainable=tune.with_resources(
            tune.with_parameters(TuneObjective(config,
                                               run_name,
                                               device,
                                               multi_instance_logger=multi_instance_logger
                                               )),
            resources=resources
        ),
        tune_config=tune.TuneConfig(
            metric="criterion",
            mode="min",
            num_samples=1,
            max_concurrent_trials=max_concurrent_trails
        ),
        run_config=tune.RunConfig(
            failure_config=tune.FailureConfig(fail_fast=True),
            storage_path=os.path.join(REPO_ROOT, "runs")
        ),
        param_space={key: tune.grid_search(value) for key, value in search_space.items()},
    )
    try:
        tuner.fit()
    finally:
        # Close the logger even when a trial failure aborts the run.
        multi_instance_logger.close()
