from ray import tune

from algorithms.a2c_aux import Experiment


# Atari games included in the sweep, listed in alphabetical order. Each entry
# is the game-name prefix of a Gym environment id; the ``__main__`` block
# below appends 'NoFrameskip-v4' to every name to form the actual env id.
GAME_LIST = '''
    Alien Amidar Assault Asterix Asteroids Atlantis BankHeist
    BattleZone BeamRider Bowling Boxing Breakout Centipede ChopperCommand
    CrazyClimber DemonAttack DoubleDunk Enduro FishingDerby Freeway Frostbite
    Gopher Gravitar Hero IceHockey Jamesbond Kangaroo Krull
    KungFuMaster MontezumaRevenge MsPacman NameThisGame Pong PrivateEye Qbert
    Riverraid RoadRunner Robotank Seaquest SpaceInvaders StarGunner Tennis
    TimePilot Tutankham UpNDown Venture VideoPinball WizardOfWor Zaxxon
'''.split()


if __name__ == '__main__':
    # Configuration passed to ``Experiment`` for every Ray Tune trial.
    # Only ``env_id`` is a grid search (one trial per game in GAME_LIST);
    # every other entry is the same for all trials.
    config = {
        'label': 'mhvp-stop-gradient',

        # Environment: one "<Game>NoFrameskip-v4" id per entry of GAME_LIST.
        'env_id': tune.grid_search([
            game + 'NoFrameskip-v4'
            for game in GAME_LIST
        ]),
        'env_kwargs': {},

        # Network architecture.
        'torso_type': 'atari_shallow',
        'torso_kwargs': {
            'dense_layers': (),
        },
        'use_rnn': False,
        'head_layers': (512,),
        # NOTE(review): the gradient-stopping switch this sweep is named after
        # ('mhvp-stop-gradient'); its exact semantics live in
        # algorithms.a2c_aux — confirm there.
        'stop_ac_grad': True,

        # A2C rollout and loss settings.
        'nenvs': 16,
        'nsteps': 20,
        'gamma': 0.99,
        # NOTE(review): presumably the GAE/lambda-return mixing factor;
        # 1.0 would mean plain n-step returns — confirm in Experiment.
        'lambda_': 1.,
        'vf_coef': 0.5,
        'entropy_reg': 0.01,

        # Optimizer for the A2C loss, with gradient-norm clipping at 0.5.
        'a2c_opt_type': 'rmsprop',
        'a2c_opt_kwargs': {
            'learning_rate': 7E-4,
            'decay': 0.99,
            'eps': 1E-5,
        },
        'max_a2c_grad_norm': 0.5,

        # Optimizer for the auxiliary loss.
        # NOTE(review): kwarg names (b1, b2, eps_root) look like optax.adam's;
        # b1 = 0 would disable first-moment averaging — confirm.
        'aux_opt_type': 'adam',
        'aux_opt_kwargs': {
            'learning_rate': 7E-4,
            'b1': 0.,
            'b2': 0.99,
            'eps_root': 1E-5,
        },
        # NOTE(review): 0. presumably means "no clipping" rather than
        # "clip to zero" — confirm in Experiment.
        'max_aux_grad_norm': 0.,

        # Auxiliary TD network.
        'td_net_type': 'mixed_open_loop_planning',
        'td_net_kwargs': {
            'seed': None,
            'depth': 0,
            'repeat': 0,
            # Ten discount factors gamma = 1 - 1/max(1, tau), one per horizon
            # tau in {0, 10, ..., 90}. max() guards tau = 0 against division
            # by zero, which maps it to gamma = 0 (0.0, 0.9, 0.95, ...).
            'discount_factors': tuple(
                1. - 1. / max(1., tau) for tau in range(0, 100, 10)
            ),
        },
        'aux_coef': 1,
        'aux_update_freq': 1,

        # Quantity the auxiliary predictions target.
        'target_feature': 'reward',
        'target_feature_kwargs': {},

        'log_interval': 100,
        'seed': 42,
    }

    # Launch the sweep: one GPU per trial; each trial is stopped once its
    # reported 'num_frames' metric reaches 200M. The result object stays
    # bound so it can be inspected when the script is run with ``python -i``.
    analysis = tune.run(
        Experiment,
        name='atari_mhvp_stop_gradient',
        config=config,
        stop={
            'num_frames': 200 * 10 ** 6,
        },
        resources_per_trial={
            'gpu': 1,
        },
    )
