import argparse
import psutil
from td3 import *
from utils.run_utils import setup_logger_kwargs
import copy
import os
import sys


def _parse_args():
    """Parse TD3 hyperparameters from the command line (``sys.argv``).

    Returns:
        argparse.Namespace with env, hid, l, gamma, seed, epochs, exp_name,
        update_every, steps_per_epoch, noise and device. When ``--device`` is
        not given, ``device`` is set to ``'cuda'`` if
        ``torch.cuda.is_available()`` and to ``'cpu'`` otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='HalfCheetah-v3')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp-name', type=str, default='td3')
    parser.add_argument('--update-every', type=int, default=50)
    parser.add_argument('--steps-per-epoch', type=int, default=4000)
    parser.add_argument('--noise', type=float, default=0.1)
    parser.add_argument('--device', type=str, default=None)
    args = parser.parse_args()
    if args.device is None:
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return args


def td3_main(args=None, redirect=True):
    """Train a TD3 agent and return the object produced by ``td3(...)``.

    Args:
        args: Namespace with attributes env, hid, l, gamma, seed, epochs,
            exp_name, update_every, steps_per_epoch, noise and device. When
            None, the values are parsed from the command line. A
            caller-supplied namespace is used as-is (its ``device`` is not
            auto-resolved).
        redirect: When True, stdout is redirected for the duration of the run
            to ``results/td3/<exp_name>_s<seed>.txt`` (relative to the current
            working directory).

    Returns:
        The value returned by ``td3(...)``.

    stdout is restored and the log file is closed even if training raises.
    """
    import contextlib  # local import so the file's top-level imports stay untouched

    if args is None:
        args = _parse_args()

    results_dir = 'results/td3'
    os.makedirs(results_dir, exist_ok=True)

    with contextlib.ExitStack() as stack:
        if redirect:
            log_path = os.path.join(results_dir, f'{args.exp_name}_s{args.seed}.txt')
            log_file = stack.enter_context(open(log_path, 'w+'))
            # redirect_stdout restores the previous sys.stdout on exit,
            # including when td3(...) raises.
            stack.enter_context(contextlib.redirect_stdout(log_file))

        # The logger output dir is cwd-relative, like the .txt log above.
        # (The original used os.path.realpath('__file__') -- a string
        # literal, not the module's __file__ -- which also resolves to cwd.)
        logger_data_dir = os.path.join(os.path.realpath(os.getcwd()), 'results', 'td3')
        logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed, logger_data_dir)

        torch.set_num_threads(1)
        logger = td3(lambda: gym.make(args.env), actor_critic=core.MLPActorCritic,
                     ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
                     gamma=args.gamma, seed=args.seed, epochs=args.epochs,
                     logger_kwargs=logger_kwargs, steps_per_epoch=args.steps_per_epoch,
                     device=args.device, update_every=args.update_every, noise=args.noise)

        # Optional plotting of the run (disabled):
        # draw_args = copy.deepcopy(args)
        # draw_args.dir_names = [os.path.join('results/td3', args.exp_name)]
        # draw_args.exp_names = ['TD3']
        # draw_args.save_dir = os.path.join('results/td3', args.exp_name, args.exp_name + '_s' + str(args.seed))
        # draw_args.target_keys = ['AverageTestEpRet', 'LossQ', 'LossPi']
        # draw_args.filename = 'progress.txt'
        # draw_args.prefix = ''
        # from draw import draw
        # draw(draw_args)
    return logger


if __name__ == '__main__':
    # Script entry point: hyperparameters come from the command line and
    # training output is written to the per-run log file instead of the console.
    td3_main(args=None, redirect=True)
