import datetime
import os

import ray
from ray import tune

import numpy as np
import tensorflow as tf
import tree


def datetime_stamp(divider='-', datetime_divider='T'):
    """Return the current local time as a timestamp string.

    The result has the form ``YYYY<divider>MM<divider>DD<datetime_divider>
    HH<divider>MM<divider>SS``; with the defaults this looks like
    ``2020-01-31T13-45-07``.

    Args:
        divider: Separator placed between the date fields and between the
            time fields.
        datetime_divider: Separator placed between the date part and the
            time part.

    Returns:
        The formatted timestamp string.
    """
    # Escape '%' so user-supplied separators are emitted literally instead of
    # being interpreted as strftime directives.
    date_sep = divider.replace('%', '%%')
    datetime_sep = datetime_divider.replace('%', '%%')
    template = '%Y{d}%m{d}%d{dtd}%H{d}%M{d}%S'.format(
        d=date_sep, dtd=datetime_sep)
    return datetime.datetime.now().strftime(template)


def set_seeds(seed):
    """Seed NumPy's and TensorFlow's global random number generators.

    Args:
        seed: Value handed to both ``np.random.seed`` and
            ``tf.random.set_seed``.
    """
    import tensorflow
    import numpy

    for seed_fn in (numpy.random.seed, tensorflow.random.set_seed):
        seed_fn(seed)


def set_gpu_memory_growth(growth):
    """Set TensorFlow's memory-growth flag on every physical GPU.

    Does nothing when TensorFlow reports no GPU. Otherwise applies ``growth``
    to each physical GPU and prints how many physical and logical GPUs exist.

    Args:
        growth: Boolean passed to
            ``tf.config.experimental.set_memory_growth`` for each GPU.
    """
    import tensorflow as tf

    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    if not physical_gpus:
        return

    # Currently, memory growth needs to be the same across GPUs, so the
    # same setting is applied to each one.
    for device in physical_gpus:
        tf.config.experimental.set_memory_growth(device, growth)

    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(physical_gpus), "Physical GPUs,", len(logical_gpus),
          "Logical GPUs")


class ExperimentRunner(tune.Trainable):
    """Ray Tune trainable that prepares the process and steps ``self.task``.

    ``_setup`` enables GPU memory growth, seeds the RNGs, optionally forces
    eager execution of TensorFlow functions and stores the variant.
    ``_train`` advances ``self.task`` by one epoch per Tune iteration and
    marks the trial done once ``total_samples`` steps have been reached.

    Keys read from the variant:
        run_params.seed
        run_params.run_eagerly (optional, defaults to False)
        experiment_params.epoch_length
        experiment_params.total_samples

    NOTE(review): ``self.task`` is never assigned in this class; presumably a
    subclass creates it (e.g. in its own ``_setup``) — confirm.
    """

    def _setup(self, variant):
        """Configure process-wide state for this trial and store ``variant``."""
        # Set the current working directory such that the local mode
        # logs into the correct place. This would not be needed outside
        # local mode.
        # NOTE(review): ``os.chdir(os.getcwd())`` is a no-op as written; the
        # intended target was probably the trial's log directory — confirm
        # before changing, since the working directory is shared by every
        # trial running in the same local-mode process.
        if ray.worker._mode() == ray.worker.LOCAL_MODE:
            os.chdir(os.getcwd())

        set_gpu_memory_growth(True)

        run_params = variant['run_params']
        set_seeds(run_params['seed'])

        if run_params.get('run_eagerly', False):
            # Force tf.function-wrapped code to run eagerly when requested.
            tf.config.experimental_run_functions_eagerly(True)

        self._variant = variant

    def _train(self):
        """Run one epoch of ``self.task`` and return the results for Tune."""
        # Read every experiment parameter from the same stored variant.
        experiment_params = self._variant['experiment_params']
        epoch_length = experiment_params['epoch_length']

        result = self.task.step(num_steps=epoch_length)
        # Convert every leaf of the (possibly nested) result to a numpy array.
        result = tree.map_structure(np.array, result)

        # Assumes ``self.iteration`` is Tune's count of completed iterations
        # (0 on the first call), so this is the step total after this epoch.
        training_steps = (self.iteration + 1) * epoch_length
        result.update({
            tune.result.DONE: (
                experiment_params['total_samples'] <= training_steps),
            'training_steps': training_steps,
        })
        return result
