import pytest

# from baselines.acer import acer_simple as acer
from baselines.common.tests.envs.mnist_env import MnistEnv
from baselines.common.tests.util import simple_test
from baselines.run import get_learn_function


# TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem?
# GitHub issue https://github.com/openai/baselines/issues/189
common_kwargs = {
    'seed': 0,
    'network':'cnn',
    'gamma':0.9,
    'pad':'SAME'
}

learn_args = {
    'a2c': dict(total_timesteps=50000),
    'acer': dict(total_timesteps=20000),
    'deepq': dict(total_timesteps=5000),
    'acktr': dict(total_timesteps=30000),
    'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0),
    'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001)
}


#tests pass, but are too slow on travis. Same algorithms are covered
# by other tests with less compute-hungry nn's and by benchmarks
@pytest.mark.skip
@pytest.mark.slow
@pytest.mark.parametrize("alg", learn_args.keys())
def test_mnist(alg):
    '''
    Test if the algorithm can learn to classify MNIST digits.
    Uses CNN policy.
    '''

    learn_kwargs = learn_args[alg]
    learn_kwargs.update(common_kwargs)

    learn = get_learn_function(alg)
    learn_fn = lambda e: learn(env=e, **learn_kwargs)
    env_fn = lambda: MnistEnv(episode_len=100)

    simple_test(env_fn, learn_fn, 0.6)

if __name__ == '__main__':
    test_mnist('acer')
