from __future__ import print_function, division
import os
os.environ["OMP_NUM_THREADS"] = "1"
import argparse
import torch
import torch.multiprocessing as mp
from environment import atari_env
from utils import read_config
from model import AClstm
from train import train
from test_ import test_, test_itr
from shared_optim import SharedRMSprop, SharedAdam, SharedSGD
#from gym.configuration import undo_logger_setup
import time

#undo_logger_setup()
def _str2bool(value):
    """Convert a command-line token to a real ``bool``.

    Without a converter argparse stores the raw string, and every non-empty
    string (including ``"False"``) is truthy, so ``--load False`` would still
    enable loading.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` for true/t/yes/y/1/on and ``False`` for false/f/no/n/0/off or
        the empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (e.g. True/False), got {0!r}'.format(value))


# Command-line interface of the launcher. The parser is built at import time;
# parsing itself only happens in the ``__main__`` section of this script.
parser = argparse.ArgumentParser(description='Async-AC')
parser.add_argument(
    '--lr',
    type=float,
    default=0.00015,
    metavar='LR',
    help='learning rate (default: 0.00015)')
parser.add_argument(
    '--gamma',
    type=float,
    default=0.99,
    metavar='G',
    help='discount factor for rewards (default: 0.99)')
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)')
parser.add_argument(
    '--workers',
    type=int,
    default=16,
    metavar='W',
    help='how many training processes to use (default: 16)')
parser.add_argument(
    '--num-steps',
    type=int,
    default=20,
    metavar='NS',
    help='number of forward steps in A3C (default: 20)')
parser.add_argument(
    '--max-episode-length',
    type=int,
    default=10000,
    metavar='M',
    help='maximum length of an episode (default: 10000)')
parser.add_argument(
    '--env',
    default='Pong-v0',
    metavar='ENV',
    help='environment to train on (default: Pong-v0)')
parser.add_argument(
    '--env-config',
    default='config.json',
    metavar='EC',
    help='environment to crop and resize info (default: config.json)')
# Boolean switches take an explicit value (``--load True``); _str2bool makes
# sure ``False``/``0``/``no`` really disable them.
parser.add_argument(
    '--load',
    type=_str2bool,
    default=False,
    metavar='L',
    help='load a trained model')
parser.add_argument(
    '--save-max',
    type=_str2bool,
    default=True,
    metavar='SM',
    help='Save model on every test run high score matched or bested')
parser.add_argument(
    '--optimizer',
    default='Adam',
    metavar='OPT',
    help='shares optimizer choice of Adam, RMSprop or SGD')
# parser.add_argument(
#     '--clip-grad-norm',
#     type=float,
#     default=0.5,
#     metavar='CGN',
#     help='Ratio of gradient norm clipped ')
parser.add_argument(
    '--test-itr-interval',
    type=float,
    default=-1,
    metavar='TII',
    help='Interval between entries of iteration log (default: no itreration log)')
parser.add_argument(
    '--load-model-dir',
    default='trained_models/',
    metavar='LMD',
    help='folder to load trained models from')
parser.add_argument(
    '--save-model-dir',
    default='trained_models/',
    metavar='SMD',
    help='folder to save trained models')
parser.add_argument(
    '--log-dir', default='logs/', metavar='LG', help='folder to save logs')
parser.add_argument(
    '--max-num-runs',
    type=int,
    default=1,
    metavar='MNR',
    help='Maximum number of independent random runs (default: 1)')
parser.add_argument(
    '--minutes-per-run',
    type=float,
    default=60.,
    metavar='MPR',
    help='Minutes per independent run (default: 60.)')
# With nargs='+' a user-supplied value is a list of ints, while the default is
# the bare int -1; consumers must accept both forms.
parser.add_argument(
    '--gpu-ids',
    type=int,
    default=-1,
    nargs='+',
    help='GPUs to use [-1 CPU only] (default: -1)')
parser.add_argument(
    '--amsgrad',
    type=_str2bool,
    default=False,
    metavar='AM',
    help='Adam optimizer amsgrad parameter')
parser.add_argument(
    '--skip-rate',
    type=int,
    default=4,
    metavar='SR',
    help='frame skip rate (default: 4)')


def _select_env_conf(setup_json, env_name):
    """Pick the configuration section that applies to ``env_name``.

    The "Default" section is the starting point. Every key of ``setup_json``
    that is contained in ``env_name`` replaces the current choice, so the last
    matching key in iteration order wins.
    """
    env_conf = setup_json["Default"]
    for section in setup_json:
        if section in env_name:
            env_conf = setup_json[section]
    return env_conf


def _build_shared_optimizer(args, shared_model):
    """Create the optimizer named by ``args.optimizer`` for ``shared_model``.

    Raises:
        Exception: If ``args.optimizer`` is not 'RMSprop', 'Adam' or 'SGD'.
    """
    choice = args.optimizer
    if choice == 'RMSprop':
        return SharedRMSprop(shared_model.parameters(), lr=args.lr)
    if choice == 'Adam':
        return SharedAdam(
            shared_model.parameters(), lr=args.lr, amsgrad=args.amsgrad)
    if choice == 'SGD':
        # Convolutional and LSTM parameters form one group with sigma 0; the
        # actor and critic heads get their own groups with their own sigma.
        trunk = (shared_model.conv1, shared_model.conv2, shared_model.conv3,
                 shared_model.conv4, shared_model.lstm)
        base_params = [
            param for module in trunk for param in module.parameters()
        ]
        param_groups = [
            {'params': base_params, 'sigma': 0},
            {'params': shared_model.actor_linear.parameters(), 'sigma': 0.6},
            {'params': shared_model.critic_linear.parameters(), 'sigma': 0.4},
        ]
        return SharedSGD(param_groups, lr=args.lr)
    raise Exception('The optimizer is not implemented.')


def _spawn(target, proc_args, processes):
    """Start ``target(*proc_args)`` in a new process and append it to ``processes``.

    Pauses for 0.1 s after starting so that launches are staggered.
    """
    proc = mp.Process(target=target, args=proc_args)
    proc.start()
    processes.append(proc)
    time.sleep(0.1)


if __name__ == '__main__':
    args = parser.parse_args()
    max_score = 0
    for num_run in range(args.max_num_runs):
        # NOTE(review): args.seed is updated in place, so the offset
        # accumulates across runs (seed, seed+1, seed+3, ...) -- confirm that
        # this is intended.
        args.seed += num_run
        torch.manual_seed(args.seed)

        cpu_only = args.gpu_ids == -1 or args.gpu_ids == [-1]
        if not cpu_only:
            torch.cuda.manual_seed(args.seed)
            # The start method is set only once, on the first run.
            if num_run == 0:
                mp.set_start_method('spawn')
        else:
            # Normalise the scalar default to the list form.
            args.gpu_ids = [-1]

        env_conf = _select_env_conf(read_config(args.env_config), args.env)
        env = atari_env(args.env, env_conf, args)

        # The model is sized from the environment's observation/action spaces.
        num_inputs = env.observation_space.shape[0]
        shared_model = AClstm(num_inputs, env.action_space)
        if args.load:
            checkpoint_path = '{0}{1}.dat'.format(args.load_model_dir, args.env)
            saved_state = torch.load(
                checkpoint_path,
                map_location=lambda storage, loc: storage)
            shared_model.load_state_dict(saved_state)
        shared_model.share_memory()

        optimizer = _build_shared_optimizer(args, shared_model)
        optimizer.share_memory()

        counter = mp.Value('i', 0)
        lock = mp.Lock()
        processes = []

        # Launch order: test_ first, then test_itr (only when an interval is
        # configured), then the training workers with ranks 1..workers.
        _spawn(test_,
               (0, args, env_conf, shared_model, counter, num_run, max_score),
               processes)
        if args.test_itr_interval > 0:
            _spawn(test_itr,
                   (0, args, env_conf, shared_model, counter, num_run),
                   processes)
        for rank in range(1, args.workers + 1):
            _spawn(train,
                   (rank, args, shared_model, optimizer, env_conf, lock,
                    counter),
                   processes)

        # Wait for every process of this run before starting the next one.
        for proc in processes:
            time.sleep(0.1)
            proc.join()
