from __future__ import print_function, division
import os
os.environ["OMP_NUM_THREADS"] = "1"
import argparse
import torch
import torch.multiprocessing as mp
import numpy as np
from synthetic_env import synthetic_env
from utils import read_config
from model import AClinear
from train import train
from test_ import test_, test_itr
from shared_optim import SharedSGD
#from gym.configuration import undo_logger_setup
import time

#undo_logger_setup()
# Command-line interface for the asynchronous TTS actor-critic launcher.
# The parsed namespace is used by the ``__main__`` section below and is also
# forwarded unchanged to the test_/test_itr/train worker processes.
parser = argparse.ArgumentParser(description='Async-TTS-AC')

# --- optimisation hyper-parameters -----------------------------------------
parser.add_argument(
    '--lr',
    type=float,
    default=1e-2,
    metavar='LR',
    help='Initial learning rate (default: 1e-2)')
# sigma1 / sigma2 are handed to SharedSGD as the per-group 'sigma' option of
# the actor and critic parameter groups respectively. How SharedSGD applies
# them lives in shared_optim (presumably an lr decay exponent -- confirm).
parser.add_argument(
    '--sigma1',
    type=float,
    default=0.6,
    metavar='SIG1',
    help='lr decreasing rate of actor optimizer')
parser.add_argument(
    '--sigma2',
    type=float,
    default=0.4,
    metavar='SIG2',
    help='lr decreasing rate of critic optimizer')
parser.add_argument(
    '--gamma',
    type=float,
    default=0.95,
    metavar='G',
    help='discount factor for rewards (default: 0.95)')

# --- run / process layout ---------------------------------------------------
# Base seed; the launcher derives a per-run seed from it for each of the
# --max-num-runs independent runs.
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)')
# Training processes are started with ranks 1..workers (rank 0 is used by
# the evaluation process(es)).
parser.add_argument(
    '--workers',
    type=int,
    default=7,
    metavar='W',
    help='how many training processes to use (default: 7)')
parser.add_argument(
    '--num-steps',
    type=int,
    default=1,
    metavar='NS',
    help='number of steps/batch size (default:1 for tts ac)')
parser.add_argument(
    '--test-episode-length',
    type=int,
    default=100,
    metavar='TEL',
    help='maximum length of an episode in test (default: 100)')
# Values <= 0 disable the additional iteration-based evaluation process
# (test_itr); only a positive interval launches it.
parser.add_argument(
    '--test-itr-interval',
    type=float,
    default=-1,
    metavar='TII',
    help='Iteration interval between two log entries (default: -1)')

# --- I/O ---------------------------------------------------------------------
parser.add_argument(
    '--env-config',
    default='config.json',
    metavar='EC',
    help='synthetic environment parameters (default: config.json)')
parser.add_argument(
    '--log-dir', default='logs/', metavar='LG', help='folder to save logs')
parser.add_argument(
    '--max-num-runs',
    type=int,
    default=1,
    metavar='MNR',
    help='Maximum number of independent random runs (default: 1)')
parser.add_argument(
    '--minutes-per-run',
    type=float,
    default=5.,
    metavar='MPR',
    help='Minutes per independent run (default: 5.)')
# With nargs='+' argparse yields a list whenever the flag is given, so the
# default is a list too; ``args.gpu_ids`` is then always a list of ints.
# [-1] means CPU only; anything else enables CUDA seeding and 'spawn'.
parser.add_argument(
    '--gpu-ids',
    type=int,
    default=[-1],
    nargs='+',
    help='GPUs to use [-1 CPU only] (default: -1)')



if __name__ == '__main__':
    # Launch --max-num-runs independent runs. Each run builds a fresh shared
    # model and shared optimizer, starts one evaluation process (plus an
    # optional iteration-based one) and --workers training processes, then
    # waits for all of them to finish before the next run begins.
    args = parser.parse_args()

    # Keep the user-supplied seed so that run k uses exactly base_seed + k.
    # (Incrementing args.seed in place accumulates: 1, 2, 4, 7, 11, ...)
    base_seed = args.seed

    # --gpu-ids may be the int -1 (older scalar default) or a list; normalise
    # to a list once. [-1] means CPU only.
    if args.gpu_ids == -1 or args.gpu_ids == [-1]:
        args.gpu_ids = [-1]
    use_gpu = args.gpu_ids != [-1]
    if use_gpu:
        # PyTorch requires the 'spawn' (or 'forkserver') start method to use
        # CUDA in subprocesses; it may only be set once per program.
        mp.set_start_method('spawn')

    for num_run in range(args.max_num_runs):
        args.seed = base_seed + num_run
        torch.manual_seed(args.seed)
        if use_gpu:
            torch.cuda.manual_seed(args.seed)
        # NOTE(review): numpy is imported but never seeded here; if
        # synthetic_env or the workers draw from np.random, runs are not
        # reproducible from --seed alone -- confirm.

        env_conf = read_config(args.env_config)
        env = synthetic_env(env_conf["state_space"], env_conf["action_space"],
                            env_conf["state_dim"])
        # The model and optimizer state live in shared memory so every worker
        # process updates the same parameters asynchronously.
        shared_model = AClinear(env.state_space, env.action_space, env.state_dim)
        shared_model.share_memory()

        # Separate parameter groups so actor and critic get their own sigma.
        optimizers = SharedSGD([{'params': shared_model.actor_linear.parameters(),
                                 'sigma': args.sigma1},
                                {'params': shared_model.critic_linear.parameters(),
                                 'sigma': args.sigma2}], lr=args.lr)
        optimizers.share_memory()

        processes = []
        # Shared integer counter (reset per run) handed to tester and workers.
        counter = mp.Value('i', 0)
        # Lock shared by the training workers only.
        lock = mp.Lock()

        # Evaluation process (rank 0).
        p = mp.Process(target=test_, args=(0, args, env_conf, shared_model,
                                           counter, num_run))
        p.start()
        processes.append(p)
        time.sleep(0.1)
        # Optional iteration-based evaluation; disabled for intervals <= 0.
        if args.test_itr_interval > 0:
            p = mp.Process(target=test_itr, args=(0, args, env_conf, shared_model,
                                                  counter, num_run))
            p.start()
            processes.append(p)
            time.sleep(0.1)

        # Training workers use ranks 1..workers; starts are staggered slightly.
        for rank in range(1, args.workers + 1):
            p = mp.Process(
                target=train, args=(rank, args, shared_model, optimizers, env_conf,
                                    lock, counter))
            p.start()
            processes.append(p)
            time.sleep(0.1)

        # Block until every process of this run has exited.
        for p in processes:
            time.sleep(0.1)
            p.join()
