from ray import tune

from algorithms.a2c_aux import Experiment


# Number of actions for each Atari game, keyed by the bare game name
# (without the "NoFrameskip-v4" suffix used to build env ids).
# Insertion order defines the order of GAME_LIST below.
NUM_ACTIONS = {
    "BeamRider": 9,
    "Breakout": 4,
    "Pong": 6,
    "Qbert": 6,
    "Seaquest": 18,
    "SpaceInvaders": 6,
}

# Games swept by the experiment. Derived from NUM_ACTIONS so every listed
# game is guaranteed to have an action count entry.
GAME_LIST = list(NUM_ACTIONS)


if __name__ == '__main__':
    # Suffix appended to each bare game name to form the Gym env id. The
    # target-feature config below strips it again to look up NUM_ACTIONS, so
    # both sites derive from this single string rather than a hard-coded
    # slice length that would silently break if the suffix changed.
    env_suffix = 'NoFrameskip-v4'

    config = {
        'label': 'shallow',

        # Ray Tune grid search: one trial per game in GAME_LIST,
        # e.g. 'PongNoFrameskip-v4'.
        'env_id': tune.grid_search([
            game + env_suffix
            for game in GAME_LIST
        ]),
        'env_kwargs': {},

        'torso_type': 'atari_shallow',
        'torso_kwargs': {
            'dense_layers': (),
        },
        'use_rnn': False,
        'head_layers': (512,),
        'stop_ac_grad': True,

        'nenvs': 16,
        'nsteps': 20,
        'gamma': 0.99,
        'lambda_': 1.,
        'vf_coef': 0.5,
        'entropy_reg': 0.01,

        'a2c_opt_type': 'rmsprop',
        'a2c_opt_kwargs': {
            'learning_rate': 7E-4,
            'decay': 0.99,
            'eps': 1E-5,
        },
        'max_a2c_grad_norm': 0.5,

        'aux_opt_type': 'adam',
        'aux_opt_kwargs': {
            'learning_rate': 7E-4,
            'b1': 0.,
            'b2': 0.99,
            'eps_root': 1E-5,
        },
        # NOTE(review): 0. presumably disables aux gradient clipping —
        # confirm against Experiment.
        'max_aux_grad_norm': 0.,

        'td_net_type': 'mixed_open_loop_planning',
        'td_net_kwargs': {
            'seed': None,
            'depth': 0,
            'repeat': 0,
            'discount_factors': (0.,),
        },
        'aux_coef': 1,
        'aux_update_freq': 1,

        'target_feature': 'random_feature',
        'target_feature_kwargs': {
            # Resolved per trial from that trial's env_id: the first entry is
            # 8 * (the game's action count) + 1.
            # NOTE(review): the tuple presumably reads
            # (channels, kernel_h, kernel_w) — confirm in random_feature.
            'conv_layers': tune.sample_from(
                lambda spec: (
                    (8 * NUM_ACTIONS[spec.config.env_id[:-len(env_suffix)]] + 1,
                     21, 21),
                )),
            'dense_layers': (),
            'padding': 'VALID',
            'w_init': 'orthogonal',
            'w_init_scale': 8.,
            'delta': True,
            'absolute': True,
            'only_last_channel': True,
        },

        'log_interval': 100,
        'seed': 42,
    }
    analysis = tune.run(
        Experiment,
        name='atari6_shallow',
        config=config,
        # Stop each trial once its reported 'num_frames' reaches 200M.
        stop={
            'num_frames': 200 * 10 ** 6,
        },
        # Each trial requests one GPU.
        resources_per_trial={
            'gpu': 1,
        },
    )
