# TODO: delete this one-off repair script once all grid_search experiments are fixed
import json
import multiprocessing as mp
import os.path
import sys

from sacred.serializer import restore

# Make project code importable when this file is run as a script:
#   1. ``<script dir>/../src``
#   2. this script's own directory
#   3. ``<script dir>/../experiments``
# These appends must execute before the project-local imports that follow
# (grid_search, experiment_utils, utils.metadata); their order is the
# lookup order among these three directories.
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, 'src')))
sys.path.append(os.path.realpath(os.path.dirname(__file__)))
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir, 'experiments')))
from grid_search import experiment as grid_exp
from experiment_utils import run_command, remove_sacred_garbage
from utils.metadata import LOG_DIRECTORY


def _load_json(path):
    """Read and return the JSON document stored at *path* (decoded as UTF-8)."""
    with open(path, mode='r', encoding='utf-8') as f:
        return json.load(f)


def _extract_train_ids(info):
    """Return the ``train_id`` of every training experiment listed in *info*.

    Two layouts are accepted: a top-level ``training_experiments`` list, or
    one nested under ``info['fold_0']`` (only that fold's ids are used).

    Raises:
        KeyError: if neither layout is present in *info*.
    """
    if 'training_experiments' in info:
        training_experiments = info['training_experiments']
    elif 'fold_0' in info:
        training_experiments = info['fold_0']['training_experiments']
    else:
        raise KeyError("info.json contains neither 'training_experiments' nor 'fold_0'")
    return [exp['train_id'] for exp in training_experiments]


def compute_crossval(exp_path, args):
    """Launch ``grid_exp`` again for one existing grid_search run.

    Loads ``info.json`` and ``config.json`` from *exp_path*, stores the run's
    recorded ``train_id`` values and the requested number of test folds in
    ``config['params']``, applies any optional command-line overrides, and
    runs ``grid_exp`` through ``run_command`` with the resulting config.

    Args:
        exp_path: directory of one grid_search run; must contain
            ``info.json`` and ``config.json``.
        args: parsed command-line namespace; reads ``test_folds``,
            ``validation_metric``, ``evaluation_metrics`` and ``device``.

    Raises:
        FileNotFoundError: if either JSON file is missing.
        KeyError: if no training run ids can be found in ``info.json`` or an
            expected config section is missing.
    """
    # Because of sacred and jsonpickle we need to remove garbage twice!
    info = remove_sacred_garbage(_load_json(os.path.join(exp_path, 'info.json')))
    info = remove_sacred_garbage(restore(info))

    config = restore(_load_json(os.path.join(exp_path, 'config.json')))

    train_ids = _extract_train_ids(info)

    params = config['params']
    params['test_folds'] = args.test_folds
    params['train_ids'] = train_ids
    # Only override settings the user explicitly passed; otherwise keep the
    # values stored with the original run.
    if args.validation_metric is not None:
        params['validation_metric'] = args.validation_metric
    if args.evaluation_metrics is not None:
        params['evaluation_metrics'] = args.evaluation_metrics
    if args.device is not None:
        params['device'] = args.device
        config['training_param_updates']['training']['device'] = args.device

    run_command(grid_exp, config_updates=config)


def main(args):
    """Run :func:`compute_crossval` for grid_search experiments under ``args.log_dir``.

    Each experiment is processed in its own freshly spawned child process,
    one at a time (the next one starts only after the previous one exited).

    Args:
        args: parsed command-line namespace; ``log_dir`` and ``experiments``
            are read here and the whole namespace is forwarded to
            :func:`compute_crossval`. When ``experiments`` is ``None``, every
            sub-directory of ``<log_dir>/grid_search`` is processed.

    Returns:
        list[str]: names of experiments whose child process exited with a
        non-zero status (empty when every run succeeded).
    """
    grid_search_root = os.path.join(args.log_dir, 'grid_search')
    if args.experiments is None:
        # Sort for a deterministic processing order, and skip stray files:
        # only a directory can hold a run's info.json / config.json.
        experiments = sorted(
            entry for entry in os.listdir(grid_search_root)
            if os.path.isdir(os.path.join(grid_search_root, entry))
        )
    else:
        experiments = args.experiments

    # One fresh 'spawn' process per experiment, presumably so that sacred's
    # process-global state cannot carry over from one run to the next.
    ctx = mp.get_context('spawn')
    failed = []
    for entry in experiments:
        # Entries whose name starts with '_' are not treated as experiments.
        if entry.startswith('_'):
            continue

        exp_path = os.path.join(grid_search_root, entry)
        proc = ctx.Process(target=compute_crossval, args=(exp_path, args))
        proc.start()
        proc.join()
        # Previously a crashed worker went unnoticed; report it instead.
        if proc.exitcode != 0:
            print(f'compute_crossval failed for {entry!r} '
                  f'(exit code {proc.exitcode})', file=sys.stderr)
            failed.append(entry)
        proc.close()

    return failed


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Re-run existing grid_search experiments with cross-validation settings.')
    parser.add_argument('--log-dir', type=str, default=LOG_DIRECTORY,
                        help='log directory containing the grid_search/ folder '
                             '(default: %(default)s)')
    parser.add_argument('--test-folds', type=int, default=5,
                        help='number of test folds (default: %(default)s)')
    parser.add_argument('--validation-metric', type=str, default=None,
                        help='override the validation metric stored in each config')
    parser.add_argument('--experiments', type=str, default=None, nargs='+',
                        help='experiment directory names to process '
                             '(default: all entries under <log-dir>/grid_search)')
    parser.add_argument('--evaluation-metrics', type=str, default=None, nargs='+',
                        help='override the evaluation metrics stored in each config')
    parser.add_argument('--device', type=str, default=None,
                        help='override the device stored in each config')

    failed = main(parser.parse_args())
    # A truthy result from main() (e.g. a list of failed experiments) means
    # at least one run failed: exit non-zero so shell scripts / CI notice.
    sys.exit(1 if failed else 0)
