from ray import tune

from algorithms.prediction import Experiment


if __name__ == '__main__':
    # Launch a Ray Tune sweep of the `Experiment` trainable on the
    # 'empty-room-touch-tree' setup, varying only the TD-network tree depth.

    # Training budget in environment frames. It is used both as the
    # experiment's own `max_frames` setting and as Tune's stopping criterion,
    # so it is defined once to keep the two from drifting apart.
    max_frames = 1 * 10 ** 6

    config = {
        'label': 'empty-room-touch-tree',

        'env_kwargs': {
            'size': 9,
            'observation_type': 'one_hot',
        },

        'torso_type': 'maze_shallow',
        'torso_kwargs': {
            'conv_layers': (),
            'dense_layers': (64, 64, 32),
        },
        'head_layers': (32,),
        'stop_value_grad': True,

        'nenvs': 8,
        'nsteps': 8,
        'max_frames': max_frames,
        'gamma': 0.98,

        'pe_opt_type': 'adam',
        'pe_opt_kwargs': {
            'learning_rate': 1E-3,
        },
        # NOTE(review): presumably 0. disables gradient-norm clipping;
        # confirm against Experiment's handling of this key.
        'max_pe_grad_norm': 0.,

        'target_feature': 'touch',
        'target_feature_kwargs': {},

        'td_net_type': 'cond_tree_sum',
        'td_net_kwargs': {
            # tune.grid_search expands into one trial per listed value,
            # so this sweep runs four trials (depth 1 through 4).
            'depth': tune.grid_search([1, 2, 3, 4]),
            'balance_by_depth': False,
        },

        'aux_opt_type': 'adam',
        'aux_opt_kwargs': {
            'learning_rate': 1E-3,
        },
        # NOTE(review): as above, 0. likely means "no clipping" — confirm.
        'max_aux_grad_norm': 0.,
        'aux_update_freq': 1,

        'log_interval': 100,
        'seed': 42,
    }
    analysis = tune.run(
        Experiment,
        name='empty_room_touch_tree',
        config=config,
        # Stop each trial once its reported `num_frames` reaches the same
        # budget the experiment itself was configured with.
        stop={
            'num_frames': max_frames,
        },
        # Each trial is allotted a single CPU.
        resources_per_trial={
            'cpu': 1,
        },
    )
