import argparse
import psutil
from sac import *
from noda import *
from utils.run_utils import setup_logger_kwargs
import copy
import os
import sys


def sac_main(args=None, redirect=True):
    """Configure and launch a single SAC training run.

    Args:
        args: Namespace-like object carrying the run settings (``env``,
            ``hid``, ``l``, ``gamma``, ``seed``, ``epochs``, ``exp_name``,
            ``update_every``, ``steps_per_epoch``, ``noise``, ``device``).
            When ``None``, the settings are parsed from the command line and
            ``device`` falls back to ``'cuda'`` if available, else ``'cpu'``.
            A caller-supplied ``args`` is used as-is (no device fallback).
        redirect: When true, ``sys.stdout`` is pointed at
            ``results/sac/<exp_name>_s<seed>.txt`` for the duration of the
            run and restored afterwards, even if the run raises.

    Returns:
        Whatever ``sac(...)`` returns (bound to the name ``logger`` here).
    """
    if args is None:
        parser = argparse.ArgumentParser()
        parser.add_argument('--env', type=str, default='HalfCheetah-v3')
        parser.add_argument('--hid', type=int, default=256)
        parser.add_argument('--l', type=int, default=2)
        parser.add_argument('--gamma', type=float, default=0.99)
        parser.add_argument('--seed', '-s', type=int, default=0)
        parser.add_argument('--epochs', type=int, default=50)
        parser.add_argument('--exp-name', type=str, default='sac')
        parser.add_argument('--update-every', type=int, default=50)
        parser.add_argument('--steps-per-epoch', type=int, default=4000)
        parser.add_argument('--noise', type=float, default=0.1)
        parser.add_argument('--device', type=str, default=None)
        args = parser.parse_args()
        if args.device is None:
            args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # exist_ok avoids the check-then-create race when several seeds start
    # at the same time and all try to create the directory.
    os.makedirs('results/sac', exist_ok=True)

    original_stdout = sys.stdout
    log_file = None
    if redirect:
        log_file = open('results/sac/' + args.exp_name + '_s' + str(args.seed) + '.txt', 'w+')
        sys.stdout = log_file

    try:
        # NOTE(review): '__file__' is a string literal here, not the module
        # attribute, so the path resolves against the current working
        # directory -- consistent with the relative 'results/sac' directory
        # created above. Confirm this is intended before changing it.
        logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed,
                                            os.path.dirname(os.path.realpath('__file__')) + '/results/sac')
        torch.set_num_threads(1)
        logger = sac(lambda: gym.make(args.env), actor_critic=core.MLPActorCritic,
                     ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
                     gamma=args.gamma, seed=args.seed, epochs=args.epochs,
                     logger_kwargs=logger_kwargs, steps_per_epoch=args.steps_per_epoch,
                     device=args.device, update_every=args.update_every, noise=args.noise)
    finally:
        # Always undo the redirection and release the file handle, so a
        # failed run does not leave the process printing into a log file.
        if log_file is not None:
            sys.stdout = original_stdout
            log_file.close()
    return logger


# Script entry point: parse settings from the command line and send the
# run's stdout to the per-seed log file under results/sac.
if __name__ == "__main__":
    sac_main(args=None, redirect=True)
