import time
import functools
import tensorflow as tf

#from baselines import logger
from baselines.logger import configure

from baselines.common import set_global_seeds, explained_variance
from baselines.common import tf_util
from baselines.common.policies import build_policy


from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.runner_exp2 import Runner
from baselines.ppo2.ppo2 import safemean
from collections import deque

from tensorflow import losses

import sys
import multiprocessing
from baselines.common.cmd_util import make_vec_env, make_env
from baselines.common.tf_util import get_session
from baselines.common.vec_env import VecFrameStack, VecNormalize
import gym
import re
from collections import defaultdict
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder

from importlib import import_module
import os.path as osp

import math

import os

from mpi4py import MPI

import gc
from array import array
import numpy

#import mpi4py
#mpi4py.rc.recv_mprobe = False

# Number of cores granted by the batch system: taken from the NSLOTS
# environment variable when it is set, otherwise 1.
NUMCORES = int(os.environ.get("NSLOTS", 1))
print("Using", NUMCORES, "core(s)")

#tf.config.threading.set_inter_op_parallelism_threads(NUMCORES) 
#tf.config.threading.set_intra_op_parallelism_threads(NUMCORES)
#tf.config.set_soft_device_placement(1)

def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    ``submodule`` defaults to ``alg``. The ``baselines`` package is tried first;
    if that import raises ImportError, the same path under ``rl_algs`` is
    imported instead (any ImportError from that second attempt propagates).
    """
    dotted_tail = [alg, submodule or alg]
    try:
        return import_module('.'.join(['baselines'] + dotted_tail))
    except ImportError:
        return import_module('.'.join(['rl_' + 'algs'] + dotted_tail))

def get_learn_function_defaults(alg, env_type):
    """Return default learn() kwargs for ``alg`` on ``env_type``.

    Calls ``<alg package>.defaults.<env_type>()``. If the defaults module cannot
    be imported, or the call raises AttributeError (including a missing
    ``env_type`` function), an empty dict is returned.
    """
    try:
        defaults_module = get_alg_module(alg, 'defaults')
        make_defaults = getattr(defaults_module, env_type)
        return make_defaults()
    except (ImportError, AttributeError):
        return {}

def get_default_network(env_type):
    """Return ``'cnn'`` for the 'atari' and 'retro' env types, ``'mlp'`` otherwise."""
    return 'cnn' if env_type in {'atari', 'retro'} else 'mlp'

def constfn(val):
    """Return a one-argument callable that ignores its argument and yields ``val``."""
    return lambda _: val

# Group every registered gym env id by its "env type": the last dotted
# component of the module part of its entry point
# (e.g. an entry point 'pkg.sub.module:Cls' gives the type 'module').
# TODO: solve this with regexes
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    env_type = env.entry_point.partition(':')[0].rpartition('.')[2]
    _game_envs[env_type].add(env.id)

# The retro games are listed by hand: reading the benchmark names from retro
# itself would require importing retro here, which was found to crash
# tensorflow on Ubuntu.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'FinalFight-Snes',
    'SonicTheHedgehog-Genesis',
    'SpaceHarrier-Nes',
    'SpaceInvaders-Snes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'Vectorman-Genesis',
}

def get_env_type(args):
    """Resolve ``(env_type, env_id)`` from ``args.env`` and ``args.env_type``.

    - If ``args.env_type`` is not None it is returned unchanged with ``args.env``.
    - Otherwise the gym registry is re-scanned into ``_game_envs`` (envs may
      have been registered since import), then:
        * if ``args.env`` is itself an env type, that type is used and one of
          its env ids (whichever the set yields first) becomes ``env_id``;
        * else the type whose id set contains ``args.env`` is used, and an id
          containing ``':'`` overrides it with the text before the colon.

    Raises:
        AssertionError: if no env type can be determined for ``args.env``.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since we could have new envs since last time.
    for env in gym.envs.registry.all():
        env_type = env.entry_point.split(':')[0].split('.')[-1]
        _game_envs[env_type].add(env.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        env_type = env_id
        # Sets are unordered: this takes whichever id iteration yields first.
        env_id = next(iter(_game_envs[env_type]))
    else:
        env_type = None
        for g, e in _game_envs.items():
            if env_id in e:
                env_type = g
                break
        if ':' in env_id:
            env_type = re.sub(r':.*', '', env_id)
        if env_type is None:
            # Raised explicitly (same exception type as an assert) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError(
                'env_id {} is not recognized in env types {}'.format(env_id, list(_game_envs.keys())))

    return env_type, env_id

def build_env(args):
    """Create the training environment described by ``args``.

    Reads ``args.num_env``, ``args.alg``, ``args.seed``, ``args.gamestate`` and
    ``args.reward_scale`` (``args.env``/``args.env_type`` via ``get_env_type``).

    - 'atari'/'retro' env types: ``deepq`` gets a single frame-stacked env from
      ``make_env``, ``trpo_mpi`` a single plain env; every other algorithm gets
      ``make_vec_env`` with ``args.num_env`` copies (default: CPU count) wrapped
      in a 4-frame ``VecFrameStack``.
    - Any other env type: the shared TF session is configured with single
      intra/inter-op threads and GPU memory growth, then ``args.num_env``
      (default 1) copies are built with ``make_vec_env``; 'mujoco' envs are
      additionally wrapped in ``VecNormalize``.
    """
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        # Use half the reported CPUs on macOS, but never 0 (a single-core
        # machine would otherwise end up with zero environments).
        ncpu = max(1, ncpu // 2)
    nenv = args.num_env or ncpu
    alg = args.alg
    seed = args.seed

    env_type, env_id = get_env_type(args)

    if env_type in {'atari', 'retro'}:
        if alg == 'deepq':
            env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
        elif alg == 'trpo_mpi':
            env = make_env(env_id, env_type, seed=seed)
        else:
            frame_stack_size = 4
            env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
            env = VecFrameStack(env, frame_stack_size)

    else:
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=1,
                                inter_op_parallelism_threads=1)
        config.gpu_options.allow_growth = True
        get_session(config=config)

        # 'her' consumes dict observations as-is; every other algorithm gets them flattened.
        flatten_dict_observations = alg not in {'her'}
        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)

        if env_type == 'mujoco':
            env = VecNormalize(env, use_tf=True)

    return env

class Model(object):

    """
    Actor-critic model trained with an A2C-style loss and RMSProp.

        __init__:
        - Creates the step_model (batch of `nenvs`, one step) and the
          train_model (batch of `nenvs * nsteps`) inside variable scope `model_type`
        - Builds the loss, (optionally clipped) gradients and the RMSProp update op

        train():
        - Runs one update (forward pass and backpropagation of gradients)

        save/load():
        - Save / load variables through baselines.common.tf_util
    """
    def __init__(self, model_type, policy, env, nsteps,
            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
            alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):

        sess = tf_util.get_session()
        self.sess = sess
        nenvs = env.num_envs
        nbatch = nenvs*nsteps

        # AUTO_REUSE: both policies share the variables created under `model_type`.
        with tf.variable_scope(model_type, reuse=tf.AUTO_REUSE):
            # step_model is used for sampling
            step_model = policy(nenvs, 1, sess)

            # train_model is used to train our network
            train_model = policy(nbatch, nsteps, sess)

        A = tf.placeholder(train_model.action.dtype, train_model.action.shape)
        ADV = tf.placeholder(tf.float32, [nbatch])
        R = tf.placeholder(tf.float32, [nbatch])
        LR = tf.placeholder(tf.float32, [])

        # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss

        # Policy loss: L = A(s,a) * -logpi(a|s)
        neglogpac = train_model.pd.neglogp(A)
        pg_loss = tf.reduce_mean(ADV * neglogpac)

        # Entropy bonus, to discourage premature convergence to a suboptimal policy.
        entropy = tf.reduce_mean(train_model.pd.entropy())

        # Value loss
        vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)

        loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef

        # Only this model's own trainable variables are updated.
        params = find_trainable_variables(model_type)

        grads = tf.gradients(loss, params)
        if max_grad_norm is not None:
            # Clip the gradients by their global norm
            grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        # Pair each gradient with its variable, e.g. zip(ABCD, wxyz) => Aw, Bx, Cy, Dz
        grads = list(zip(grads, params))

        # Op for one policy and value update step
        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)

        _train = trainer.apply_gradients(grads)

        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        def train(obs, states, rewards, masks, actions, values):
            """Run one update on a batch; return (policy_loss, value_loss, policy_entropy)."""
            # Advantage A(s,a) = R + yV(s') - V(s), where `rewards` holds R + yV(s')
            advs = rewards - values
            # The scheduler is advanced once per row of `obs`; the last value is used.
            for _ in range(len(obs)):
                cur_lr = lr.value()

            td_map = {train_model.X: obs, A: actions, ADV: advs, R: rewards, LR: cur_lr}
            if states is not None:
                td_map[train_model.S] = states
                td_map[train_model.M] = masks
            policy_loss, value_loss, policy_entropy, _ = sess.run(
                [pg_loss, vf_loss, entropy, _train],
                td_map
            )
            return policy_loss, value_loss, policy_entropy

        # The softmax op over step_model.pi is built on the first _step call and
        # then reused; creating it inside every call would add a new op to the TF
        # graph each step, so the graph (and memory) would grow without bound.
        step_probs = []

        def _step(observation, **kwargs):
            """Evaluate [action, softmax(pi), state, q, vf, neglogp] of the step model."""
            if not step_probs:
                step_probs.append(tf.nn.softmax(step_model.pi))
            return step_model._evaluate([step_model.action, step_probs[0], step_model.state, step_model.q, step_model.vf, step_model.neglogp], observation, **kwargs)

        self.train = train
        self.train_model = train_model
        self.step_model = step_model
        self.step = step_model.step
        self._step = _step
        self.value = step_model.value
        self.initial_state = step_model.initial_state
        self.save = functools.partial(tf_util.save_variables, sess=sess)
        self.load = functools.partial(tf_util.load_variables, sess=sess)
        # NOTE: this initialises *all* global variables in the graph, including
        # those of any Model constructed earlier.
        tf.global_variables_initializer().run(session=sess)


def _drain_reward_messages(comm, source, poll_interval=0.4):
    """Receive every reward message (tag 0, one float64) currently pending from `source`.

    A non-blocking receive is posted. If it has not completed, it is tested once
    more after `poll_interval` seconds. A completed receive is consumed and a new
    one is posted. An incomplete one is cancelled, and draining stops.

    Returns:
        The value carried by the last message received, or None if none arrived.
    """
    recv_buf = numpy.zeros(1, dtype='double')
    latest = None
    while True:
        recv_buf[:] = 0
        req = comm.Irecv(recv_buf, source=source, tag=0)
        status = MPI.Status()
        if req.Get_status():
            received = req.Test(status)
        else:
            time.sleep(poll_interval)
            print('BEFORE wait!')
            received = req.Test(status)
            print('AFTER wait!')

        keep_polling = received
        if not received:
            # A cancel can lose the race against a message matched at the same
            # moment. Completing the request and checking Is_cancelled() tells the
            # two cases apart, so such a message is recorded instead of being lost
            # (or landing in the buffer after we stopped looking at it).
            req.Cancel()
            req.Wait(status)
            received = not status.Is_cancelled()

        if received:
            print('a2c before rew {}'.format(source))
            latest = recv_buf[0]
            print('a2c after rew {}'.format(source))
            gc.collect()

        if not keep_polling:
            return latest


def _prune_completed_sends(pending_sends):
    """Drop (request, payload) entries whose non-blocking send has completed.

    Each entry keeps its payload array referenced, so MPI never reads freed or
    overwritten memory while the send is still in flight.
    """
    pending_sends[:] = [(req, payload) for req, payload in pending_sends if not req.Test()]


def learn(
    args,
    extra_args,
    uni_tstart,
    network,
    suffix=None,
    nsteps=5,
    vf_coef=0.5,
    ent_coef=0.01,
    max_grad_norm=0.5,
    lr=7e-4,
    lrschedule='linear',
    epsilon=1e-5,
    alpha=0.99,
    gamma=0.99,
    log_interval=10,
    load_path=None):

    '''
    Train an A2C policy while exchanging rewards and parameters with other MPI ranks.

    Three Model instances are built on the same env and policy constructor (variable
    scopes 'a2c_model', 'ppo2_model' and 'acer_model') and handed to the Runner. Only
    the 'a2c_model' one is trained here, on every batch returned by runner.run().

    MPI traffic (tag 0 = mean episode reward, tag 1 = model parameters):
      - every update, all pending reward messages from ranks 1 and 2 are drained into
        rewmean[1] / rewmean[2] (NOTE(review): assumes this process is rank 0);
      - every 10 updates, the trainable 'a2c_model' variables are sent to every other
        rank, followed by this process's mean episode reward (if it is not NaN).

    Parameters:
    -----------

    args:               parsed command-line namespace; reads env, env_type, num_env, alg, seed, gamestate,
                        reward_scale, log_path, num_timesteps, save_video_interval and save_video_length.

    extra_args:         not used by this function.

    uni_tstart:         reference time.time() value; "wallclock_time" is logged relative to it.

    network:            policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
                        specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
                        tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
                        neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
                        See baselines.common/policies.py/lstm for more details on using recurrent nets in policies

    suffix:             appended to the log suffix '-a2c' (None means nothing is appended).

    nsteps:             int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                        nenv is number of environment copies simulated in parallel)

    vf_coef:            float, coefficient in front of value function loss in the total loss function (default: 0.5)

    ent_coef:           float, coefficient in front of the policy entropy in the total loss function (default: 0.01)

    max_grad_norm:      float, gradient is clipped to have global L2 norm no more than this value (default: 0.5)

    lr:                 float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)

    lrschedule:         learning-rate schedule name passed to baselines.a2c.utils.Scheduler (e.g. 'linear', 'constant')

    epsilon:            float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)

    alpha:              float, RMSProp decay parameter (default: 0.99)

    gamma:              float, reward discounting parameter (default: 0.99)

    log_interval:       int, log every this many updates, plus the first update (default: 10)

    load_path:          optional path of saved variables, restored into the a2c model after all three models are built.

    Returns:
    --------
    The trained 'a2c_model' Model.
    '''
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    logger = configure(args.log_path, log_suffix='-a2c' + (suffix or ''))

    env = build_env(args)
    if args.save_video_interval != 0:
        env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)

    total_timesteps = int(args.num_timesteps)
    seed = args.seed

    set_global_seeds(seed)

    env_type, env_id = get_env_type(args)

    nenvs = env.num_envs
    policy = build_policy(env, network)

    # Instantiate the models (each creates its own step_model and train_model)
    model = Model(model_type='a2c_model', policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
        max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
    # Snapshot of the trainable variables that exist right after the a2c model is
    # built; the checkpoint is restored into exactly these.
    a2c_load_vars = tf.trainable_variables()

    model_ppo2 = Model(model_type='ppo2_model', policy=policy, env=env, nsteps=512, ent_coef=ent_coef, vf_coef=vf_coef,
        max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
    model_acer = Model(model_type='acer_model', policy=policy, env=env, nsteps=21, ent_coef=ent_coef, vf_coef=vf_coef,
        max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)

    # Every Model constructor runs tf.global_variables_initializer(), which resets
    # all global variables, so restoring before the ppo2/acer models exist would have
    # the loaded weights overwritten. NOTE(review): relies on tf_util.load_variables
    # accepting `variables=` as in upstream baselines — confirm for this fork.
    if load_path is not None:
        model.load(load_path, variables=a2c_load_vars)

    runner = Runner(env, model, model_ppo2, model_acer, nsteps=nsteps, gamma=gamma)
    epinfobuf = deque(maxlen=100)

    nbatch = nenvs*nsteps

    tstart = time.time()

    # rewmean[0]: this process's own mean episode reward; rewmean[1] / rewmean[2]:
    # latest values received from ranks 1 / 2 (None until one arrives).
    rewmean = [None, None, None]
    # Non-blocking sends still in flight, kept together with their payloads.
    pending_sends = []
    a2c_params = find_trainable_variables("a2c_model")

    Params_cur = [[] for i in range(size)]
    Params_pre = [[] for i in range(size)]
    Arr_index = [0 for i in range(size)]

    for update in range(1, total_timesteps//nbatch+1):
        _prune_completed_sends(pending_sends)

        for source in (1, 2):
            if source >= size:
                continue  # that rank does not exist in this run
            latest = _drain_reward_messages(comm, source)
            if latest is not None:
                rewmean[source] = latest

        # Get mini batches of experiences and train on each of them
        ret = runner.run(rewmean, Arr_index, Params_cur, Params_pre)
        for obs, states, rewards, masks, actions, values, epinfos in ret:
            if epinfos is not None:
                epinfobuf.extend(epinfos)
            policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)

        if update % 10 == 0:
            param_val = model.sess.run(a2c_params)
            print('a2c model length: ' + str(len(param_val)))
            for j in range(size):
                if j == rank:
                    continue
                for value in param_val:
                    pending_sends.append((comm.Isend(value, dest=j, tag=1), value))

            a2crewmean = safemean([epinfo['r'] for epinfo in epinfobuf])
            if not math.isnan(a2crewmean):
                # Dedicated buffer: it must stay untouched until every send completes.
                payload = numpy.array([a2crewmean], dtype='double')
                for k in range(size):
                    if k != rank:
                        pending_sends.append((comm.Isend(payload, dest=k, tag=0), payload))
                rewmean[0] = a2crewmean

        nseconds = time.time()-tstart
        wallclock_time = time.time() - uni_tstart

        # Calculate the fps (frame per second)
        fps = int((update*nbatch)/nseconds)
        if update % log_interval == 0 or update == 1:
            print("a2c update: " + str(update))
            # Calculates if value function is a good predicator of the returns (ev > 1)
            # or if it's just worse than predicting nothing (ev =< 0)
            ev = explained_variance(values, rewards)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("policy_entropy", float(policy_entropy))
            logger.logkv("value_loss", float(value_loss))
            logger.logkv("explained_variance", float(ev))
            logger.logkv("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv("wallclock_time", wallclock_time)
            logger.dumpkvs()
    return model

