import collections.abc
import copy
import functools
import os.path
import sys
from concurrent.futures import ProcessPoolExecutor as Pool

import sacred.utils
import torch
import torch.utils.data
# from torch.multiprocessing import Pool

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, 'src')))
sys.path.append(os.path.realpath(os.path.dirname(__file__)))
from experiment_utils import make_experiment, data_ingredient, run_command, serialization_guard, remove_sacred_garbage
from utils.utils import Bunch, param_grid_to_list_of_dicts, str2cls
from utils.torch_utils import set_threads


# Experiment object of this script. Its `config` and `automain` decorators are used below to register the
# configuration and the entry point; `data_ingredient` is passed in as the experiment's only ingredient.
# NOTE(review): make_experiment is a project helper - presumably it builds a sacred Experiment; confirm.
experiment = make_experiment(ingredients=[data_ingredient])


def train_once(config_updates, params):
    """Run a single training experiment and return the ID of the resulting run.

    This function is used as the work item of the process pool in ``main``,
    so it is called once per point of the training parameter grid.

    Args:
        config_updates: Configuration overrides handed to the training
            experiment via ``run_command``.
        params: Evaluation parameters. Only ``threads_per_process`` and
            ``training_experiment`` are read here.

    Returns:
        The ``_id`` attribute of the run object returned by ``run_command``.
    """
    # Apply the per-process thread setting before any training work is started.
    set_threads(params.threads_per_process)

    # Look up the experiment object of the configured training module
    # (presumably str2cls resolves a dotted path to the object it names).
    experiment_path = f'{params.training_experiment}.experiment'
    training_experiment = str2cls(experiment_path)

    print('training with parameters', config_updates)
    finished_run = run_command(training_experiment, config_updates=config_updates)
    return finished_run._id


@experiment.config
def config():
    """Default configuration of the grid search experiment.

    Defines four configuration entries: ``params`` (general settings),
    ``training_param_grid`` (training parameters to search over),
    ``training_param_updates`` (fixed training parameter overrides) and
    ``detector_param_grid`` (detector parameters to search over).
    """
    # General Evaluation parameters
    params = dict(
        # The training experiment that will be executed. Please provide a relative path from the experiments and
        # use '.' as the separator
        training_experiment='reconstruction.train_lstm_ae',
        # This will be computed over the validation set and used to determine the best parameters in the grid
        validation_metric='best_f1_score',
        # These metrics will be calculated on the test set
        evaluation_metrics=['best_f1_score', 'auprc'],
        batch_size=128,
        device='cpu',
        # Number of processes to train in parallel and CPU threads used for each process
        exp_processes=1,
        threads_per_process=1,
        test_folds=5,
        # NOTE(review): presumably the IDs of already finished training runs (None = nothing trained yet);
        # verify against main
        train_ids=None,
        padding=1
    )

    # Use this to specify training parameters that should be searched over. They will take precedence over the
    # values specified in param_updates. The format for this is that you define a list of
    # possible values for each attribute. For example:
    # training_param_grid = dict(
    #     model_params=dict(
    #         hidden_dimensions=[[40], [50]]
    #     ),
    #     training=dict(
    #         optimizer = {
    #             'args': dict(lr=[1e-3, 0.01])
    #         }
    #     )
    # )
    training_param_grid = dict()

    # Use this to overwrite training parameters with a single value. Will be overwritten with the values in
    # training_param_grid
    training_param_updates = dict()

    # This should contain all values for the detector parameters that should be searched over.
    # Note that this should contain only parameters of the detector, for which no retraining is needed.
    detector_param_grid = dict()


@data_ingredient.config
def data_config():
    """Configuration overrides for the data ingredient used by this experiment."""
    # NOTE(review): presumably forwarded to the dataset constructor so that the non-training part of the
    # data is loaded - confirm against the data ingredient
    ds_args = dict(
        training=False
    )
    # NOTE(review): presumably the fractions of the validation and the test split - confirm against the
    # data ingredient
    split = (0.3, 0.7)

@experiment.automain
@serialization_guard
def main(params, training_param_grid, training_param_updates, detector_param_grid, dataset, _run):
    """Train one model per point of the training parameter grid and return the run IDs.

    The arguments carry the names of the configuration entries of this
    experiment and of the data ingredient (presumably injected by the
    experiment framework; ``_run`` is not used by this function).

    Args:
        params: General evaluation parameters (see ``config``).
        training_param_grid: Grid of training parameters to search over.
        training_param_updates: Fixed training parameter overrides. Values
            from ``training_param_grid`` take precedence over them.
        detector_param_grid: Grid of detector parameters.
        dataset: Configuration of the data ingredient. Its ``'pipeline'``
            entry is either a single pipeline update or a sequence of them.
        _run: Unused.

    Returns:
        A list of training run IDs. If ``params['train_ids']`` is not None,
        training is skipped and these IDs are returned as a list; otherwise
        one ID per point of the training grid, in grid order.
    """
    params = dict(params)
    # IDs of training runs that were finished earlier (None means: train now). Read from the plain dict
    # so that a missing key simply counts as "not provided".
    finished_train_ids = params.get('train_ids')

    # Take batch dimension from training experiment
    params['batch_dim'] = str2cls(f'{params["training_experiment"]}.get_batch_dim')()
    params = Bunch(params)

    # set_seed(seed)
    # run_deterministic()
    # Each worker process is configured with `threads_per_process` threads (see train_once), so the main
    # process is given the combined thread budget of all workers.
    set_threads(params.threads_per_process * params.exp_processes)

    # Build the dataset configuration. The default pipeline is taken from the training experiment and is
    # updated with the pipeline(s) given in the dataset configuration. A sequence of pipeline updates yields
    # one independent copy of the default pipeline per update.
    # NOTE(review): `dataset` is not consumed further down in this function; the preparation is kept so that
    # the configuration is still resolved exactly as before.
    pipeline = str2cls(f'{params.training_experiment}.get_test_pipeline')()
    if isinstance(dataset['pipeline'], collections.abc.Sequence):
        pipeline = [sacred.utils.recursive_update(copy.deepcopy(pipeline), pipe) for pipe in dataset['pipeline']]
    else:
        sacred.utils.recursive_update(pipeline, dataset['pipeline'])
    dataset = remove_sacred_garbage(dataset)
    dataset['pipeline'] = pipeline

    # Expand the grids into lists of concrete parameter dicts. Each training grid point is merged on top of a
    # fresh copy of the fixed updates, so grid values win and the points do not share state.
    training_param_grid = list(param_grid_to_list_of_dicts(training_param_grid))
    updated_param_grid = [sacred.utils.recursive_update(copy.deepcopy(training_param_updates), point)
                          for point in training_param_grid]
    # NOTE(review): the expanded detector grid is not used further down in this function.
    detector_param_grid = list(param_grid_to_list_of_dicts(detector_param_grid))

    # Only train if no information on already finished training runs is provided
    if finished_train_ids is not None:
        return list(finished_train_ids)

    # Note: We need to use spawn here, otherwise CUDA will produce errors
    context = torch.multiprocessing.get_context('spawn')

    train_proc_func = functools.partial(train_once, params=params)
    with Pool(max_workers=params.exp_processes, mp_context=context) as pool:
        # pool.map keeps the input order, so the IDs line up with updated_param_grid
        train_ids = list(pool.map(train_proc_func, updated_param_grid))

    return train_ids
