from __future__ import print_function, division
import os
os.environ["OMP_NUM_THREADS"] = "1"
import argparse
import torch
import torch.multiprocessing as mp
import gym
from gym_extensions.continuous import gym_navigation_2d
from utils import read_config
from model import ACMLP
from train import train
from test_ import test_, test_itr
from shared_optim import SharedSGD, SharedAdam
#from gym.configuration import undo_logger_setup
import time

#undo_logger_setup()
def _str2bool(value):
    """Convert a command-line token such as 'True' or 'false' to a bool.

    argparse hands option values over as strings, and any non-empty string
    (including 'False') is truthy, so boolean options need an explicit
    conversion.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


# Command-line interface of the launcher.  Every option below becomes an
# attribute of the namespace returned by parser.parse_args().
parser = argparse.ArgumentParser(description='Async-TTS-AC')

# --- optimisation -----------------------------------------------------------
parser.add_argument(
    '--lr',
    type=float,
    default=5e-3,
    metavar='LR',
    help='Initial learning rate (default: 0.005)')
parser.add_argument(
    '--optimizer',
    default='SGD',
    metavar='OPT',
    help='what optimizer to use (default: SGD)')
parser.add_argument(
    '--amsgrad',
    # Without an explicit converter "--amsgrad False" would be stored as the
    # truthy string 'False'.  nargs='?' + const=True additionally lets a bare
    # "--amsgrad" switch the option on.
    type=_str2bool,
    nargs='?',
    const=True,
    default=False,
    metavar='AMS',
    help='whether to use amsgrad for Adam (default: False)')
parser.add_argument(
    '--gamma',
    type=float,
    default=0.95,
    metavar='G',
    help='discount factor for rewards (default: 0.95)')
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)')
parser.add_argument(
    '--workers',
    type=int,
    default=7,
    metavar='W',
    help='how many training processes to use (default: 7)')
parser.add_argument(
    '--num-steps',
    type=int,
    default=32,
    metavar='NS',
    help='number of steps/batch size (default:32)')
parser.add_argument(
    '--max-grad-norm',
    type=float,
    default=50.,
    metavar='MGN',
    help='Threshold of gradient norm for gradient clipping (default: 50.)')

# --- logging / environment --------------------------------------------------
parser.add_argument(
    '--test-itr-interval',
    type=float,
    default=-1,
    metavar='TII',
    help='Interval between entries of iteration log (default: no itreration log)')
parser.add_argument(
    '--env-name',
    default='CartPole-v0',
    metavar='EN',
    help='environment name (default: CartPole-v0)')
parser.add_argument(
    '--log-dir', default='logs/', metavar='LG', help='folder to save logs')

# --- run control ------------------------------------------------------------
parser.add_argument(
    '--max-num-runs',
    type=int,
    default=1,
    metavar='MNR',
    help='Maximum number of independent random runs (default: 1)')
parser.add_argument(
    '--minutes-per-run',
    type=float,
    default=5.,
    metavar='MPR',
    help='Minutes per independent run (default: 5.)')
parser.add_argument(
    '--gpu-ids',
    type=int,
    # The default is the bare int -1 while explicit values arrive as a list
    # (nargs='+'); the entry point accepts both spellings.
    default=-1,
    nargs='+',
    help='GPUs to use [-1 CPU only] (default: -1)')



if __name__ == '__main__':
    # Entry point: for each independent run, build a shared model and a shared
    # optimizer, then launch one evaluation process (plus an optional
    # iteration-log process) and `args.workers` training processes, and wait
    # for all of them to finish.
    args = parser.parse_args()

    for num_run in range(args.max_num_runs):
        # The seed is advanced in place, so the offsets accumulate across
        # runs: base, base+1, base+3, base+6, ...
        args.seed += num_run
        torch.manual_seed(args.seed)

        # Normalise the CPU-only spelling: the parser default is the int -1,
        # an explicit "--gpu-ids -1" yields [-1].
        if args.gpu_ids == -1 or args.gpu_ids == [-1]:
            args.gpu_ids = [-1]
        else:
            torch.cuda.manual_seed(args.seed)
            # The start method may only be set once per interpreter, hence
            # the guard on the first run.
            if num_run == 0:
                mp.set_start_method('spawn')

        # The environment is instantiated here only to size the network.
        env = gym.make(args.env_name)
        shared_model = ACMLP(env.action_space.n, env.observation_space.shape[0])
        shared_model.share_memory()

        if args.optimizer == 'SGD':
            # Three parameter groups, each carrying its own `sigma` option
            # for SharedSGD: the two hidden layers, the actor head and the
            # critic head.
            base_params = (list(shared_model.linear1.parameters())
                           + list(shared_model.linear2.parameters()))
            param_groups = [
                {'params': base_params, 'sigma': 0.},
                {'params': shared_model.actor_linear.parameters(), 'sigma': 0.6},
                {'params': shared_model.critic_linear.parameters(), 'sigma': 0.4},
            ]
            optimizers = SharedSGD(param_groups, lr=args.lr)
        elif args.optimizer == 'Adam':
            optimizers = SharedAdam(
                shared_model.parameters(), lr=args.lr, amsgrad=args.amsgrad)
        elif args.optimizer == 'RMSprop':
            # SharedRMSprop is not among the module-level imports, so the
            # bare name used to raise NameError.  Import it here; if
            # shared_optim does not provide it, the ImportError says so
            # explicitly.
            from shared_optim import SharedRMSprop
            optimizers = SharedRMSprop(shared_model.parameters(), lr=args.lr)
        else:
            raise Exception('The optimizer is not implemented.')
        optimizers.share_memory()

        processes = []
        counter = mp.Value('i', 0)  # shared integer handed to every process
        lock = mp.Lock()            # handed to the training workers only

        # Evaluation process(es) use rank 0 and are started before the
        # training workers.
        eval_args = (0, args, shared_model, counter, num_run)
        eval_targets = [test_]
        if args.test_itr_interval > 0:
            eval_targets.append(test_itr)
        for target in eval_targets:
            p = mp.Process(target=target, args=eval_args)
            p.start()
            processes.append(p)
            time.sleep(0.1)

        # Training workers use ranks 1..workers; start-up is staggered by a
        # short sleep.
        for rank in range(1, args.workers + 1):
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, optimizers, lock, counter))
            p.start()
            processes.append(p)
            time.sleep(0.1)

        # Block until every process of this run has exited before starting
        # the next run.
        for p in processes:
            time.sleep(0.1)
            p.join()
