from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import RaySampler
from garage.torch.algos import SGDH
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers.SGD_optimizer import SGD
from garage.torch.policies import GaussianMLPPolicy
from garage.trainer import Trainer
import torch

# Learning rate handed to the SGD policy optimizer; it is also embedded in
# every run's log-directory name so sweeps over it stay distinguishable.
step_size = 1e-2


def run_task(seed):
    """Train a Gaussian MLP policy with plain SGD on Swimmer-v2 for one seed.

    The experiment wrapper is created anew on every call because its log
    directory depends on both the module-level ``step_size`` and ``seed``.

    Args:
        seed (int): Random seed forwarded to the wrapped experiment.
    """
    experiment_dir = "/root/Data/icml/final/swimmer_sgd-lr{}seed={}".format(
        step_size, seed)

    @wrap_experiment(archive_launch_repo=False, log_dir=experiment_dir)
    def sgd_swimmer(ctxt=None, seed=1):
        epochs = 1000
        batch = 10000

        # Seed the global RNGs first; the order of the constructions below
        # is kept fixed because the policy's initial weights (and anything
        # else drawing randomness) depend on it.
        set_seed(seed)
        swimmer = GymEnv('Swimmer-v2')
        # Also seed the wrapped environment object and its action space
        # directly (``_env`` is a private attribute of GymEnv).
        swimmer._env.seed(seed)
        swimmer.action_space.seed(seed)
        runner = Trainer(ctxt)

        actor = GaussianMLPPolicy(swimmer.spec,
                                  hidden_sizes=[64, 64],
                                  hidden_nonlinearity=torch.tanh,
                                  output_nonlinearity=None)

        # A linear feature baseline is used instead of a learned MLP value
        # function; this matches ``neural_baseline=False`` below.
        baseline = LinearFeatureBaseline(env_spec=swimmer.spec)

        rollout_sampler = RaySampler(agents=actor,
                                     envs=swimmer,
                                     max_episode_length=500)

        sgd = OptimizerWrapper((SGD, {'lr': step_size}), actor)

        algo_kwargs = dict(
            env_spec=swimmer.spec,
            policy=actor,
            value_function=baseline,
            sampler=rollout_sampler,
            discount=0.99,
            center_adv=False,
            policy_optimizer=sgd,
            neural_baseline=False,
        )
        runner.setup(SGDH(**algo_kwargs), swimmer)
        runner.train(n_epochs=epochs, batch_size=batch)

    sgd_swimmer(seed=seed)


# Seeds to sweep over; each one launches an independent training run.
seeds = [14, 33, 3, 4, 49]

# Launch the sweep only when this file is executed as a script, so merely
# importing the module does not start five full training runs.
if __name__ == "__main__":
    for seed in seeds:
        run_task(seed=seed)
