from ray import tune

from algorithms.prediction import Experiment


if __name__ == '__main__':
    # Launch script: assembles one configuration dict and passes it, together
    # with the project's Experiment trainable, to tune.run.

    # Single frame budget, used both as the 'max_frames' config entry and as
    # the 'num_frames' stop threshold given to tune.run below.
    total_frames = 10 ** 6

    env_kwargs = {
        'size': 9,
        'observation_type': 'one_hot',
    }

    # Dense-only torso: the conv stack is left empty.
    torso_kwargs = {
        'conv_layers': (),
        'dense_layers': (64, 64, 32),
    }

    # Two separate (non-aliased) optimiser dicts with the same learning rate.
    pe_opt_kwargs = {'learning_rate': 1e-3}
    aux_opt_kwargs = {'learning_rate': 1e-3}

    # The target-feature dense layout is grid-searched over four
    # single-layer widths.
    feature_widths = [(1,), (4,), (16,), (64,)]
    target_feature_kwargs = {
        'conv_layers': (),
        'dense_layers': tune.grid_search(feature_widths),
        'padding': 'VALID',
        'w_init': 'orthogonal',
        'w_init_scale': 1.0,
        'delta': True,
        'absolute': True,
        'only_last_channel': False,
    }

    # 'repeat' is produced by a sample_from callback that reads the first
    # entry of target_feature_kwargs['dense_layers'] from the trial's config,
    # so it follows whichever width the grid search selected.
    td_net_kwargs = {
        'seed': None,
        'depth': 4,
        'repeat': tune.sample_from(
            lambda spec: spec.config.target_feature_kwargs['dense_layers'][0]
        ),
        'discount_factors': (0.8,),
    }

    # Key order below mirrors the original layout: identification, environment,
    # network, rollout, prediction optimiser, target feature, TD network,
    # auxiliary optimiser, logging/seed.
    config = {
        'label': 'empty-room-random-feature-rgvfs',

        'env_kwargs': env_kwargs,

        'torso_type': 'maze_shallow',
        'torso_kwargs': torso_kwargs,
        'head_layers': (32,),
        'stop_value_grad': True,

        'nenvs': 8,
        'nsteps': 8,
        'max_frames': total_frames,
        'gamma': 0.98,

        'pe_opt_type': 'adam',
        'pe_opt_kwargs': pe_opt_kwargs,
        'max_pe_grad_norm': 0.0,

        'target_feature': 'random_feature',
        'target_feature_kwargs': target_feature_kwargs,

        'td_net_type': 'mixed_open_loop_planning',
        'td_net_kwargs': td_net_kwargs,

        'aux_opt_type': 'adam',
        'aux_opt_kwargs': aux_opt_kwargs,
        'max_aux_grad_norm': 0.0,
        'aux_update_freq': 1,

        'log_interval': 100,
        'seed': 42,
    }

    # Stop criterion and per-trial resource request forwarded to tune.run.
    stop_criteria = {'num_frames': total_frames}
    trial_resources = {'cpu': 1}

    analysis = tune.run(
        Experiment,
        name='empty_room_random_feature_rgvfs',
        config=config,
        stop=stop_criteria,
        resources_per_trial=trial_resources,
    )
