# Created by user3, 25/03/2018
import sys
import argparse
import os.path as osp
import datetime
import dateutil.tz
import ast
import uuid
import pickle as pickle
import base64
import joblib

from chester import config


def _build_arg_parser(default_exp_name):
    """Return the argparse parser for the experiment-launcher command line.

    :param default_exp_name: value used for ``--exp_name`` when the flag is
        not supplied on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    # ast.literal_eval lets boolean flags be passed as "True"/"False".
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
    return parser


def _decode_b64_pickle(payload, loads):
    """Base64-decode ``payload`` and deserialize it with ``loads``.

    :param payload: base64 text as received on the command line.
    :param loads: a pickle-compatible ``loads`` function (``pickle.loads``
        or ``cloudpickle.loads``).

    SECURITY: unpickling can execute arbitrary code. Only pass payloads
    produced by a trusted launcher, never user-controlled input.
    """
    return loads(base64.b64decode(payload))


def run_experiment(argv):
    """Parse launcher arguments, then resume or run a single experiment.

    :param argv: full argument vector with the program name at ``argv[0]``
        (e.g. ``sys.argv``). Only ``argv[1:]`` is parsed.
    :raises NotImplementedError: if neither ``--resume_from`` nor
        ``--use_cloudpickle True`` is given. The plain-pickle stub path is
        not implemented.
    :raises ValueError: if the resumed snapshot has no ``'algo'`` entry, or
        if ``--args_data`` is missing on the cloudpickle path.

    With ``--resume_from``, the joblib snapshot is loaded and
    ``data['algo'].train()`` is called. Otherwise ``--args_data`` is decoded
    with cloudpickle into a callable, which is invoked as
    ``callable(variant_data, log_dir, exp_name)``.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    # A short random suffix avoids name clashes when running distributed jobs.
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    args = _build_arg_parser(default_exp_name).parse_args(argv[1:])

    # NOTE: --n_parallel, --seed, --plot, --snapshot_mode, --snapshot_gap,
    # --log_tabular_only and the *_log_file flags are still accepted for CLI
    # compatibility, but they are unused here. The rllab seeding/sampler/
    # plotter/logger integration they used to drive is disabled.

    if args.log_dir is None:
        log_dir = osp.join(config.LOG_DIR, args.exp_name)
    else:
        log_dir = args.log_dir

    variant_data = None
    if args.variant_data is not None:
        variant_data = _decode_b64_pickle(args.variant_data, pickle.loads)

    if args.resume_from is not None:
        # joblib.load is pickle-based: only load trusted snapshot files.
        data = joblib.load(args.resume_from)
        if 'algo' not in data:
            raise ValueError(
                "Snapshot %r has no 'algo' entry to resume." % args.resume_from)
        data['algo'].train()
        return

    if not args.use_cloudpickle:
        # Previously `assert False`. That guard is stripped under `python -O`,
        # which let the script exit successfully without running anything.
        raise NotImplementedError(
            'Only --use_cloudpickle True is supported when not resuming '
            'from a snapshot.')
    if args.args_data is None:
        raise ValueError('--args_data is required when --use_cloudpickle is set.')

    import cloudpickle  # optional dependency, needed only on this path
    method_call = _decode_b64_pickle(args.args_data, cloudpickle.loads)
    method_call(variant_data, log_dir, args.exp_name)


if __name__ == "__main__":
    # Script entry point: hand over the complete argument vector. The
    # program name in argv[0] is skipped by run_experiment before parsing.
    run_experiment(argv=sys.argv)
