from garage.torch import set_gpu_mode
import torch
from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.sampler import RaySampler
from garage.torch.algos import VPG, TRPO, HSODM
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer
import numpy as np
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers.HSODM_optimizer import HSODMOptimizer
from garage.torch.optimizers.SGD_optimizer import SGD
from garage.np.baselines import LinearFeatureBaseline
import argparse

# Command-line interface: line-search mode, HSODM hyper-parameters and the
# random seed. Parsed eagerly at import time because the module-level
# experiment configuration (log directory name, default seed) reads ``args``.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--ls",
    type=str,
    default="backtrack",
    choices=["const", "backtrack"],
    help="step size mode",
)
parser.add_argument(
    "--ls_const_delta",
    type=float,
    default=0.05,
    help="step size if const",
)
parser.add_argument(
    "--order",
    type=int,
    default=1,
    help="order of negative curvature",
)
parser.add_argument(
    "--delta",
    type=float,
    default=1e-4,
    help="lower right element",
)
parser.add_argument(
    "--seed",
    type=int,
    default=1,
    help="random seed",
)

args = parser.parse_args()
# Echo the effective configuration so it is visible in the run output.
print(args)

class HomogeneousParam:
    """Hyper-parameter bundle handed to ``HSODMOptimizer``.

    Attributes:
        Delta: Step size used when ``line_search`` is ``"const"``
            (CLI option ``--ls_const_delta``). Default 0.05.
        line_search: Step-size mode, ``"const"`` or ``"backtrack"``
            (CLI option ``--ls``). Default ``"const"``; note the CLI
            default is ``"backtrack"``.
        order: Order of negative curvature (CLI option ``--order``).
            Default 1.
        delta: The "lower right element" (CLI option ``--delta``);
            presumably the corner entry of the homogenized matrix used by
            HSODM -- confirm against HSODMOptimizer. Default 1e-4.
    """

    def __init__(self, **kwargs):
        # Unknown keys are ignored, as in the original API. Note that the
        # "linesearch" keyword is stored on the ``line_search`` attribute.
        self.Delta = kwargs.get("Delta", 0.05)
        self.line_search = kwargs.get("linesearch", "const")
        self.order = kwargs.get("order", 1)
        self.delta = kwargs.get("delta", 1e-4)

    def __repr__(self):
        # Uses the constructor keyword names so the repr round-trips.
        return (
            f"{type(self).__name__}(Delta={self.Delta!r}, "
            f"linesearch={self.line_search!r}, order={self.order!r}, "
            f"delta={self.delta!r})"
        )

    @classmethod
    def parse(cls, args):
        """Build an instance from parsed command-line arguments.

        Args:
            args: Object exposing ``ls_const_delta``, ``ls``, ``order`` and
                ``delta`` attributes, e.g. the ``argparse.Namespace`` from
                ``parser.parse_args()``.

        Returns:
            An instance of ``cls`` (so subclasses get their own type).
        """
        return cls(
            Delta=args.ls_const_delta,
            linesearch=args.ls,
            order=args.order,
            delta=args.delta,
        )
    
# Bundle the HSODM-related CLI options into one object; it feeds both the
# experiment log directory name and the HSODMOptimizer configuration.
homogeneous_param = HomogeneousParam.parse(args=args)


@wrap_experiment(
    archive_launch_repo=False,
    log_dir="halfcheeta_hsodm_seed={}order={}Delta={}linesearch={}delta={}".format(
        args.seed,
        homogeneous_param.order,
        homogeneous_param.Delta,
        homogeneous_param.line_search,
        homogeneous_param.delta,
    ),
)
def hsodm_halfcheeta(ctxt=None, seed=args.seed):
    """Train a Gaussian MLP policy on HalfCheetah-v2 with the HSODM algorithm.

    The policy is updated by ``HSODMOptimizer``, configured from the
    module-level ``homogeneous_param`` (built from the command line). A
    ``LinearFeatureBaseline`` serves as the value function.

    Note: the log directory is formatted from ``args.seed`` when the module
    is imported. Calling this function with a different ``seed`` therefore
    does not change where results are written.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism. Defaults to the ``--seed`` command-line value.

    """
    n_epochs = 1000
    # Passed to Trainer.train as batch_size. This is presumably the number of
    # environment steps sampled per epoch -- confirm against garage's Trainer.
    sampler_batch_size = 10000

    set_seed(seed)
    env = GymEnv('HalfCheetah-v2')

    trainer = Trainer(ctxt)

    policy = GaussianMLPPolicy(env.spec,
                               hidden_sizes=[64, 64],
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)

    # Non-neural baseline; neural_baseline=False is passed to HSODM below
    # to match.
    value_function = LinearFeatureBaseline(env_spec=env.spec)

    sampler = RaySampler(agents=policy,
                         envs=env,
                         max_episode_length=500)

    # The optimizer class and its constructor kwargs are handed to
    # OptimizerWrapper together with the policy whose parameters it updates.
    policy_optimizer = OptimizerWrapper((HSODMOptimizer, {
        "homogeneous_param": homogeneous_param,
    }), policy)

    algo = HSODM(env_spec=env.spec,
                 policy=policy,
                 value_function=value_function,
                 sampler=sampler,
                 discount=0.99,
                 center_adv=False,
                 policy_optimizer=policy_optimizer,
                 neural_baseline=False,
                 )

    trainer.setup(algo, env)
    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)


# Launch training only when run as a script, so importing this module does
# not start a 1000-epoch experiment.
if __name__ == "__main__":
    hsodm_halfcheeta()



